Compare commits

...

89 Commits

Author SHA1 Message Date
Isotr0py
5506435419 [Misc] Clean up Gemma4 implementation (#38872)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
2026-04-03 05:47:02 +00:00
Yifan Qiao
311c981647 [MRV2][KVConnector] Fix missing build_connector_worker_meta (#38698)
Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
2026-04-03 08:42:52 +03:00
Li, Jiang
21d7ecc5b0 [CI/Build] Add audio deps in Dockerfile.cpu (#38876)
Signed-off-by: jiang1.li <jiang1.li@intel.com>
2026-04-03 05:05:14 +00:00
Aaron Hao
4729b90838 [Bug] Add e_score_correction_bias to SKIP_TENSORS (#38746)
Signed-off-by: ahao-anyscale <ahao@anyscale.com>
2026-04-02 21:15:05 -07:00
shunting314
8b141ed8c3 full cudagraph for flex-attn (#36298)
Signed-off-by: shunting314 <shunting@meta.com>
2026-04-02 21:15:01 -07:00
Varun Sundar Rabindranath
2ad7c0335f [Model] Add Phi4ForCausalLMV for microsoft/Phi-4-reasoning-vision-15B (#38306)
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
2026-04-02 21:14:57 -07:00
Bowen Bao
201d2ea5bf [CI][ROCm] Add Qwen3.5-35B-A3B-MXFP4 model eval into CI (#38664)
Signed-off-by: Bowen Bao <bowenbao@amd.com>
2026-04-03 04:05:45 +00:00
Bowen Bao
103f0de565 [ROCm][Quantization][1/N] Refactor quark_moe w_mxfp4 w/ oracle (#38774)
Signed-off-by: Bowen Bao <bowenbao@amd.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
2026-04-03 03:29:57 +00:00
wliao2
32e0c0bfa2 refactor hard coded device string in test files under tests/v1 and tests/lora (#37566)
Signed-off-by: Liao, Wei <wei.liao@intel.com>
2026-04-03 11:21:47 +08:00
Itay Etelis
4a06e1246e [Perf] Batch KV cache swap copies via cuMemcpyBatchAsync (#38460)
Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
Co-authored-by: Itay Etelis <itay.etelis@ibm.com>
Co-authored-by: Or Ozeri <oro@il.ibm.com>
2026-04-03 03:13:23 +00:00
Carl Y
3bc2734dd0 [Kernel] Fuse FP8 output quantization into merge_attn_states (#36518)
Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
2026-04-03 01:47:04 +00:00
Carl Y
1f5ec2889c [mla] Support fused FP8/NVFP4 output quantization in MLA attention (#35792) (#36205)
Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
Signed-off-by: Carl Y <4531192+carlyou@users.noreply.github.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-02 21:16:11 -04:00
Yan Ma
ee3cf45739 [XPU] Initial support for GDN attention on Qwen3-next/Qwen3.5 (#33657)
Signed-off-by: Yan Ma <yan.ma@intel.com>
Signed-off-by: Chendi Xue <chendi.xue@intel.com>
Co-authored-by: Chendi Xue <chendi.xue@intel.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
2026-04-03 08:59:11 +08:00
Matthew Bonanni
05e68e1f81 [CI] Fix test_nixl_connector (#38838) 2026-04-02 17:52:13 -07:00
Vadim Gimpelson
771913e4a0 [Bugfix] Fix NVFP4+MTP crash: force unquantized mtp.fc for Qwen3.5 (#38832)
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
2026-04-03 04:45:57 +04:00
1096125073
71a9125c67 [New Model]: add support for telechat3 (#38510)
Signed-off-by: xiayongqiang <xiayq1@chinatelecom.cn>
Co-authored-by: xiayongqiang <xiayq1@chinatelecom.cn>
2026-04-03 08:26:22 +08:00
Nicolò Lucchesi
66e86f1dbd [Kernel] Mamba support different layout for Conv state (#37416) 2026-04-03 01:50:09 +02:00
Michael
bb39382b2b [Bugfix]: Fix Gemma4ToolParser.__init__() missing tools parameter (#38847)
Signed-off-by: Michael Hospedales <hospedales@me.com>
2026-04-02 14:35:19 -07:00
zhanqiuhu
7b743ba953 [CI] Fix: pass string cache_dtype in test_register_kv_caches (#38836) 2026-04-02 19:42:09 +00:00
Stefano Castagnetta
188defbd0b [CI] Add flashinfer.py to attention test source deps (#38792)
Signed-off-by: Stefano Castagnetta <scastagnetta@nvidia.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
2026-04-02 19:24:29 +00:00
Luciano Martins
08ed2b9688 feat(models): implement Google Gemma 4 architecture support (MoE, Multimodal, Reasoning, Tool-Use) (#38826)
Signed-off-by: Luciano Martins <lucianommartins@users.noreply.github.com>
Signed-off-by: Luciano Martins <lucianomartins@google.com>
Co-authored-by: Luciano Martins <lucianommartins@users.noreply.github.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
2026-04-02 11:13:28 -07:00
Yanan Cao
ecd5443dbc Bump helion dependency from 0.3.2 to 0.3.3 (#38062)
Signed-off-by: Yanan Cao <gmagogsfm@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-02 10:59:33 -07:00
Stefano Castagnetta
58262dec6e [Bugfix] Fix test mocks after SM100 restriction in #38730 (#38791)
Signed-off-by: Stefano Castagnetta <scastagnetta@nvidia.com>
Co-authored-by: Claude <noreply@anthropic.com>
2026-04-02 13:12:58 -04:00
Lucas Wilkinson
cb3935a8fc [FA4] Update flash-attention to latest upstream FA4 (#38690)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
2026-04-02 17:02:37 +00:00
Bowen Bao
82a006beeb [CI][ROCm] Add gpt-oss w4a8 in CI (#38292)
Signed-off-by: Bowen Bao <bowenbao@amd.com>
2026-04-03 00:06:01 +08:00
wang.yuqi
a9b4f07ba2 [Frontend] Re-enable running MaxSim on GPU (#38620)
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
2026-04-03 00:03:13 +08:00
Koushik Dutta
d9408ffba3 Triton MLA perf fixes (#33529)
Signed-off-by: Koushik Dutta <koushd@gmail.com>
Co-authored-by: root <root@ubuntu-nvidia.localdomain>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
2026-04-02 09:40:01 -04:00
Yusuf Mohammad
16a65e4173 [Bugfix] Enable batch-invariant Triton matmul on all Ampere GPUs (SM 8x) (#38427)
Signed-off-by: yusuf <yusufmohammad@live.com>
Signed-off-by: yusuf <yusuf@deeplearningmachine.mynet>
Signed-off-by: Yusuf Mohammad <79484377+YM2132@users.noreply.github.com>
Signed-off-by: <>
Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: yusuf <yusuf@deeplearningmachine.mynet>
2026-04-02 09:29:58 -04:00
bsliu
c0817e4d39 [Model] Add support for Cheers multimodal model (#38788)
Signed-off-by: bsliu <1187291748@qq.com>
Signed-off-by: 吴炳贤 <wubingxian24@mails.ucas.ac.cn>
2026-04-02 21:01:40 +08:00
Harry Mellor
dfe5e31689 Don't compile vision encoder for Transformers backend (#30518)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2026-04-02 12:42:29 +00:00
JartX
2ce3d0ce36 [Feature] KV cache per-token-head INT8/FP8 quantization (#38378)
Signed-off-by: JartX <sagformas@epdcenter.es>
Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: yangyang4991 <yangyang4991@gmail.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
2026-04-02 08:13:26 -04:00
Jiangyun Zhu
4eefbf9609 [Perf] fuse kernels in gdn (#37813)
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
2026-04-02 11:52:18 +00:00
vllmellm
551b3fb39f [ROCm] Enable VLLM triton FP8 moe for gfx1201, tuned for Qwen3-30B-A3B-FP8 tp=2 and Qwen/Qwen3.5-35B-A3B-FP8 tp=2 (#38086)
Signed-off-by: big-yellow-duck <jeffaw99@hotmail.com>
Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com>
2026-04-02 08:13:42 +00:00
Li, Jiang
c6f722b93e [CPU] Support gelu act in cpu_fused_moe (#38770)
Signed-off-by: jiang1.li <jiang1.li@intel.com>
2026-04-02 14:14:32 +08:00
Xin Yang
9bd7231106 Revert "[Kernel] Add gpt-oss Router GEMM kernel (#37205)" (#38778)
Signed-off-by: Xin Yang <xyangx@amazon.com>
2026-04-01 22:02:32 -07:00
Yanan Cao
73f48ce559 [Kernel] [Helion] Use warning_once in get_gpu_name to prevent log spam (#38743)
Signed-off-by: Yanan Cao <gmagogsfm@gmail.com>
Co-authored-by: Claude Sonnet 4 <noreply@anthropic.com>
2026-04-01 21:30:31 -07:00
Gregory Shtrasberg
3aab680e3e [ROCm][Bugfix] Fix ROCm runtime failure due to missing symbol (#38750)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Signed-off-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: tjtanaavllm <tunjian.tan@amd.com>
2026-04-01 21:30:11 -07:00
Sergey Zinchenko
5a2d420c17 [Bugfix] Use dedicated MM processor cache in /tokenize to prevent sender-cache pollution (#38545)
Signed-off-by: Sergey Zinchenko <sergey.zinchenko.rnd@gmail.com>
2026-04-01 21:14:49 -07:00
Benjamin Chislett
5f96f9aff1 [Perf] DSV3.2 Indexer Fused Weights Projection (#38684)
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
2026-04-02 03:34:49 +00:00
Luka Govedič
694449050f Fix multiline-format string for python 3.10 (#38739)
Signed-off-by: Luka Govedic <luka.govedic@gmail.com>
2026-04-02 03:19:35 +00:00
Nick Hill
6241521dd2 [BugFix] Fix precommit breakage due to conflicting in-flight merges (#38759)
Signed-off-by: Nick Hill <nickhill123@gmail.com>
2026-04-01 15:35:06 -07:00
Kevin H. Luu
1785dc5501 Revert "[Bugfix] Fix Qwen3CoderToolParser anyOf/oneOf type resolution for nullable params (#37831)" (#38751) 2026-04-02 06:34:28 +08:00
Chang Su
54500546ac [Bugfix] Preserve original ImportError in gRPC server entrypoint (#38673)
Signed-off-by: Chang Su <chang.s.su@oracle.com>
2026-04-01 22:16:44 +00:00
Jeffrey Wang
de5e6c44c6 [Feat][Executor] Introduce RayExecutorV2 (#36836)
Signed-off-by: Jeffrey Wang <jeffreywang@anyscale.com>
2026-04-01 14:34:29 -07:00
yzong-rh
cb268e4e55 [Refactor] Simplify FutureWrapper in MultiprocExecutor (#38644)
Signed-off-by: Yifan <yzong@redhat.com>
Signed-off-by: Yifan Zong <yzong@redhat.com>
2026-04-01 21:28:26 +00:00
Stefano Castagnetta
6183cae1bd [Bugfix] Restrict TRTLLM attention to SM100, fixing GB300 (SM103) hang (#38730)
Signed-off-by: Stefano Castagnetta <scastagnetta@nvidia.com>
2026-04-01 12:08:40 -07:00
Monishver
c09ad767cd Feature/silu block quant fusion v1 (#32996)
Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
2026-04-01 18:50:43 +00:00
Wentao Ye
c9a9db0e02 [Compile] Fix nvfp4 compile warning (#38573)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
2026-04-01 18:28:57 +00:00
Chauncey
cbe7d18096 [Misc] Rename think_start_str/think_end_str to reasoning_start_str/reasoning_end_str (#38242)
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
2026-04-01 09:56:45 -07:00
Michael Goin
db5d0719e1 [Kernel] Add MXFP8 to Marlin GEMM/MoE and refactor Mxfp8LinearOp (#34664)
Signed-off-by: mgoin <mgoin64@gmail.com>
2026-04-01 09:41:42 -07:00
yzong-rh
dc0428ebb8 [NIXL][BUG] Fix Triton heterogeneous TP (#37940)
Signed-off-by: Yifan <yzong@redhat.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
2026-04-01 17:23:15 +02:00
Jesus Talavera
148c2072ec Add ibm-granite/granite-vision-3.3-2b to supported models documentation (#38714)
Signed-off-by: Jesus Talavera <jesus.talavera@ibm.com>
2026-04-01 08:22:25 -07:00
majianhan
2f5c3c1ec0 [Misc] Fix docstring typo: buildin -> builtin (#38722)
Co-authored-by: majianhan <majianhan@kylinos.cn>
2026-04-01 07:39:46 -07:00
Fynn Schmitt-Ulms
fa246d5231 Fix shape comment in extract_hidden_states example (#38723)
Signed-off-by: Fynn Schmitt-Ulms <fschmitt@redhat.com>
2026-04-01 07:29:33 -07:00
bnellnm
7cf56a59a2 [MoE Refactor] Make SharedExperts class for use with DefaultMoERunner (#35153)
Signed-off-by: Bill Nell <bnell@redhat.com>
2026-04-01 09:44:08 -04:00
Elvir Crnčević
5e30e9b9a9 [Bugfix] Revert "Zero-init MLA attention output buffers to prevent NaN from CUDA graph padding" (#38359)
Signed-off-by: Elvir Crncevic <elvircrn@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
2026-04-01 09:11:10 -04:00
손세정
582340f273 [Bugfix] Fix Qwen3CoderToolParser anyOf/oneOf type resolution for nullable params (#37831)
Signed-off-by: AAISSJ <maze0717@g.skku.edu>
Signed-off-by: <>
Co-authored-by: 세덩 <saison@sedeong-ui-MacBookAir.local>
2026-04-01 20:22:29 +08:00
yjz
992368522f [KVTransfer] Fix TpKVTopology.is_kv_replicated equality case (#38179)
Signed-off-by: JianDan0212 <zhangyj0212@gmail.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
2026-04-01 12:41:49 +02:00
Juan Pérez de Algaba
58ee614221 (security) Enforce frame limit in VideoMediaIO (#38636)
Signed-off-by: jperezde <jperezde@redhat.com>
2026-04-01 10:23:45 +00:00
Harry Mellor
f9f6a9097a Add verified label to trigger pre-commit (#38708)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2026-04-01 02:31:02 -07:00
Zhanda Zhu
c75a313824 [Perf] triton bilinear_pos_embed kernel for ViT (#37948)
Signed-off-by: Zhanda Zhu <zhandazhu@gmail.com>
2026-04-01 01:52:02 -07:00
Lukas Geiger
4f6eed3bd4 [Core] Simplify multimodal masking (#34246)
Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
2026-04-01 01:18:22 -07:00
Li, Jiang
36d7f19897 [CPU] Support head_size 512 in cpu_attn (#38676)
Signed-off-by: jiang1.li <jiang1.li@intel.com>
2026-04-01 05:42:27 +00:00
Jeffrey Wang
2d725b89c5 [Bugfix] Lazy import diskcache to avoid sqlite3/libstdc++ ImportError at startup (#38649)
Signed-off-by: Jeffrey Wang <jeffreywang@anyscale.com>
2026-04-01 05:31:20 +00:00
Augusto Yao
ef53395e2c [bugfix] do not add extra linebreak for score/rerank with chat template (#38617)
Signed-off-by: augusto.yjh <augusto.yjh@antgroup.com>
Signed-off-by: wang.yuqi <noooop@126.com>
Co-authored-by: wang.yuqi <yuqi.wang@daocloud.io>
Co-authored-by: wang.yuqi <noooop@126.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
2026-04-01 04:50:07 +00:00
Lucas Wilkinson
eb47454987 [Bugfix][MLA] Add logits size budget to sparse indexer prefill chunking (#36178)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
2026-04-01 00:15:53 -04:00
Matthew Bonanni
116f4be405 [1/N][Cleanup] Standardize on use of is_quantized_kv_cache (#38659)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
2026-04-01 04:08:01 +00:00
Wentao Ye
7b01d97a22 [Perf] Optimize mean pooling using chunks and index_add, 5.9% E2E throughput improvement (#38559)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
2026-04-01 03:54:58 +00:00
HarshRathva
17b72fd1c8 Fix priority preemption regression test in scheduler (#37051)
Signed-off-by: HarshRathva <harshrathvaai@gmail.com>
Co-authored-by: Or Ozeri <oro@il.ibm.com>
2026-04-01 06:36:12 +03:00
Samu Tamminen
c49497726b [ROCm][perf] Shuffle KV cache to use paged_attention_common (#32914)
Signed-off-by: Samu Tamminen <stammine@amd.com>
Co-authored-by: Tuukka Sarvi <tuukka.sarvi@amd.com>
2026-04-01 03:30:19 +00:00
Ben Browning
cb0b443274 [Misc] Add 20 regression tests for 11 tool parser bug fixes (#38172)
Signed-off-by: Ben Browning <bbrownin@redhat.com>
Co-authored-by: Chauncey <chaunceyjiang@gmail.com>
2026-04-01 03:00:31 +00:00
Luka Govedič
40bb175027 [vLLM IR] 1/N Implement IR skeleton and rms_norm op (#33825)
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Xinyu Chen <xinyu1.chen@intel.com>
Signed-off-by: chzhang <chaojun.zhang@intel.com>
Signed-off-by: Luka Govedic <luka.govedic@gmail.com>
Co-authored-by: Xinyu Chen <xinyu1.chen@intel.com>
Co-authored-by: Chaojun Zhang <chaojun.zhang@intel.com>
Co-authored-by: Luka Govedič <ProExpertProg@h100-01.nemg-001.lab.rdu2.dc.redhat.com>
2026-03-31 22:15:05 -04:00
Elvir Crnčević
0fab52f0aa Fix NaN from stale FP4 scale padding in create_fp4_scale_tensor (#38148)
Signed-off-by: Elvir Crncevic <elvircrn@gmail.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
2026-03-31 19:14:59 -07:00
Yifan Qiao
91e4521f9f [Feat][v1] Simple yet General CPU KV Cache Offloading (#37160)
Signed-off-by: Yifan Qiao <yifanqiao@berkeley.edu>
Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
2026-03-31 17:58:37 -07:00
Stig-Arne Grönroos
31a719bcd3 [ROCm][perf] fix Aiter sparse MLA with MTP>1 (#37887)
Signed-off-by: Stig-Arne Grönroos <stig-arne.gronroos@amd.com>
Signed-off-by: Stig-Arne Grönroos <sgronroo@amd.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
2026-03-31 19:22:23 -04:00
Vedant V Jhaveri
2e56975657 Generative Scoring (#34539)
Signed-off-by: Vedant Jhaveri <vjhaveri@linkedin.com>
Co-authored-by: Vedant Jhaveri <vjhaveri@linkedin.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
2026-03-31 16:02:11 -07:00
Chang Su
36f1dc19ae feat(grpc): add periodic stats logging and servicer log forwarding (#38333)
Signed-off-by: Chang Su <chang.s.su@oracle.com>
2026-03-31 15:50:07 -07:00
Asaf Gardin
3dc01ef352 [Quantization] Consolidate dummy format logic into DummyModelLoader (#38637)
Signed-off-by: Josephasafg <ajgard7@gmail.com>
2026-03-31 22:20:45 +00:00
Yanan Cao
cc671cb110 [Kernel] [Helion] [17/N] Add Helion kernel torch.compile support (#38592)
Signed-off-by: Yanan Cao <gmagogsfm@gmail.com>
Co-authored-by: Claude Sonnet 4 <noreply@anthropic.com>
2026-03-31 17:06:42 -04:00
Wentao Ye
856589ed9a [Refactor] Remove dead code in kv connector and model runner (#38383)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
2026-03-31 17:05:23 -04:00
czhu-cohere
517b769b58 [Perf] Fix DBO overlap: capture DeepEP event before yield (#38451)
Signed-off-by: root <conway.zhu@cohere.com>
2026-03-31 20:38:59 +00:00
yzong-rh
d9b90a07ac [MoE Refactor] Migrate Unquantized to Full Oracle Flow (#36286)
Signed-off-by: Yifan Zong <yzong@redhat.com>
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: yzong-rh <yzong@redhat.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
2026-03-31 15:43:33 -04:00
Olya Kozlova
598190aac3 [fix] Remove trtllm ragged mla prefills (#36540)
Signed-off-by: Olya Kozlova <okozlova@nvidia.com>
2026-03-31 12:30:27 -07:00
Xu Jinyang
b779eb3363 [Model] Sync upstream BT=chunk_size fix for GDN chunk_fwd_kernel_o, simplify warmup to single pass (#38343)
Signed-off-by: AuYang <459461160@qq.com>
Co-authored-by: Jiangyun Zhu <riverclouds.zhu@qq.com>
2026-03-31 23:03:24 +04:00
BadrBasowid
077a9a8e37 [torch.compile] Refactor Attention Quant Fusion Pass and Remove Boilerplate (#37373)
Signed-off-by: BadrBasowid <badr.basowid@gmail.com>
Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com>
2026-03-31 14:15:50 -04:00
Run Yu
07edd551cc [CI/Build] Resolve a dependency deadlock when installing the test dependencies used in CI (#37766)
Signed-off-by: Run Yu <yurun00@gmail.com>
2026-03-31 18:05:14 +00:00
mikaylagawarecki
7c080dd3c5 [4/n] Migrate FP4/W4A8 CUTLASS kernels to torch stable ABI (#37503)
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
2026-03-31 10:21:13 -07:00
Yi Liu
0dd25a44ea [Quantization][Autoround][XPU] Add W4A16 Support (#37986)
Signed-off-by: yiliu30 <yi4.liu@intel.com>
2026-03-31 16:48:24 +00:00
SandishKumarHN
3896e021a0 [Bugfix] Fix FusedMoE weight loading with padded hidden dimensions (#37010)
Signed-off-by: SandishKumarHN <sandish@fb.com>
2026-03-31 12:22:26 -04:00
419 changed files with 26801 additions and 4798 deletions

View File

@@ -42,6 +42,7 @@ docker run \
python3 examples/basic/offline_inference/generate.py --model superjob/Qwen3-4B-Instruct-2507-GPTQ-Int4 --block-size 64 --enforce-eager --max-model-len 8192
python3 examples/basic/offline_inference/generate.py --model ibm-research/PowerMoE-3b --block-size 64 --enforce-eager -tp 2
python3 examples/basic/offline_inference/generate.py --model ibm-research/PowerMoE-3b --block-size 64 --enforce-eager -tp 2 --enable-expert-parallel
python3 examples/basic/offline_inference/generate.py --model OPEA/Qwen2.5-0.5B-Instruct-int4-sym-inc --block-size 64 --enforce-eager --max-model-len 8192
cd tests
pytest -v -s v1/core --ignore=v1/core/test_reset_prefix_cache_e2e.py --ignore=v1/core/test_scheduler_e2e.py
pytest -v -s v1/engine

View File

@@ -790,7 +790,7 @@ steps:
- tests/kernels/helion/
- vllm/platforms/rocm.py
commands:
- pip install helion
- pip install helion==0.3.3
- pytest -v -s kernels/helion/

View File

@@ -72,6 +72,7 @@ steps:
- vllm/v1/attention/backends/flashinfer.py
- vllm/compilation/ # TODO(luka) limit to vllm/compilation/passes
- tests/compile/passes/test_fusion_attn.py
- tests/compile/passes/test_mla_attn_quant_fusion.py
- tests/compile/passes/test_silu_mul_quant_fusion.py
- tests/compile/passes/distributed/test_fusion_all_reduce.py
- tests/compile/fullgraph/test_full_graph.py
@@ -79,6 +80,7 @@ steps:
# b200 runners are limited, so we limit the tests to the minimum set only supported on Blackwell
- nvidia-smi
- pytest -v -s tests/compile/passes/test_fusion_attn.py -k FLASHINFER
- pytest -v -s tests/compile/passes/test_mla_attn_quant_fusion.py
- pytest -v -s tests/compile/passes/test_silu_mul_quant_fusion.py
# this runner has 2 GPUs available even though num_devices=2 is not set
- pytest -v -s tests/compile/passes/distributed/test_fusion_all_reduce.py

View File

@@ -224,6 +224,20 @@ steps:
commands:
- ./.buildkite/scripts/run-multi-node-test.sh /vllm-workspace/tests 2 2 $IMAGE_TAG "VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed' && NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed' && python3 ../examples/offline_inference/data_parallel.py -dp=2 -tp=1 --dp-num-nodes=2 --dp-node-rank=0 --dp-master-addr=192.168.10.10 --dp-master-port=12345 --enforce-eager --trust-remote-code && VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py && VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py" "VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed' && NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed' && python3 ../examples/offline_inference/data_parallel.py -dp=2 -tp=1 --dp-num-nodes=2 --dp-node-rank=1 --dp-master-addr=192.168.10.10 --dp-master-port=12345 --enforce-eager --trust-remote-code"
- label: MessageQueue TCP Multi-Node (2 GPUs)
timeout_in_minutes: 10
working_dir: "/vllm-workspace/tests"
num_devices: 1
num_nodes: 2
no_plugin: true
optional: true
source_file_dependencies:
- vllm/distributed/device_communicators/shm_broadcast.py
- vllm/distributed/parallel_state.py
- tests/distributed/test_mq_tcp_multinode.py
commands:
- ./.buildkite/scripts/run-multi-node-test.sh /vllm-workspace/tests 2 1 $IMAGE_TAG "torchrun --nnodes 2 --nproc-per-node=1 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_mq_tcp_multinode.py" "torchrun --nnodes 2 --nproc-per-node=1 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_mq_tcp_multinode.py"
- label: Distributed NixlConnector PD accuracy (4 GPUs)
timeout_in_minutes: 30
working_dir: "/vllm-workspace/tests"
@@ -294,3 +308,23 @@ steps:
commands:
- pytest -v -s distributed/test_pp_cudagraph.py
- pytest -v -s distributed/test_pipeline_parallel.py
- label: RayExecutorV2 (4 GPUs)
timeout_in_minutes: 60
working_dir: "/vllm-workspace/tests"
num_devices: 4
source_file_dependencies:
- vllm/v1/executor/ray_executor_v2.py
- vllm/v1/executor/abstract.py
- vllm/v1/executor/multiproc_executor.py
- tests/distributed/test_ray_v2_executor.py
- tests/distributed/test_ray_v2_executor_e2e.py
- tests/distributed/test_pipeline_parallel.py
- tests/basic_correctness/test_basic_correctness.py
commands:
- export VLLM_USE_RAY_V2_EXECUTOR_BACKEND=1
- export NCCL_CUMEM_HOST_ENABLE=0
- pytest -v -s distributed/test_ray_v2_executor.py
- pytest -v -s distributed/test_ray_v2_executor_e2e.py
- pytest -v -s distributed/test_pipeline_parallel.py -k "ray"
- TARGET_TEST_SUITE=L4 pytest -v -s basic_correctness/test_basic_correctness.py -k "ray"

View File

@@ -2,6 +2,16 @@ group: Kernels
depends_on:
- image-build
steps:
- label: vLLM IR Tests
timeout_in_minutes: 10
working_dir: "/vllm-workspace/"
source_file_dependencies:
- vllm/ir
- vllm/kernels
commands:
- pytest -v -s tests/ir
- pytest -v -s tests/kernels/ir
- label: Kernels Core Operation Test
timeout_in_minutes: 75
source_file_dependencies:
@@ -19,6 +29,7 @@ steps:
- vllm/v1/attention
# TODO: remove this dependency (https://github.com/vllm-project/vllm/issues/32267)
- vllm/model_executor/layers/attention
- vllm/utils/flashinfer.py
- tests/kernels/attention
commands:
- pytest -v -s kernels/attention --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
@@ -129,7 +140,7 @@ steps:
- vllm/utils/import_utils.py
- tests/kernels/helion/
commands:
- pip install helion
- pip install helion==0.3.3
- pytest -v -s kernels/helion/

View File

@@ -18,5 +18,6 @@ steps:
# Avoid importing model tests that cause CUDA reinitialization error
- pytest models/test_transformers.py -v -s -m 'distributed(num_gpus=2)'
- pytest models/language -v -s -m 'distributed(num_gpus=2)'
- pytest models/multimodal -v -s -m 'distributed(num_gpus=2)' --ignore models/multimodal/generation/test_whisper.py
- pytest models/multimodal/generation/test_phi4siglip.py -v -s -m 'distributed(num_gpus=2)'
- pytest models/multimodal -v -s -m 'distributed(num_gpus=2)' --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/generation/test_phi4siglip.py
- VLLM_WORKER_MULTIPROC_METHOD=spawn pytest models/multimodal/generation/test_whisper.py -v -s -m 'distributed(num_gpus=2)'

4
.github/CODEOWNERS vendored
View File

@@ -13,6 +13,9 @@
/vllm/model_executor/layers/rotary_embedding.py @vadiklyutiy
/vllm/model_executor/model_loader @22quinn
/vllm/model_executor/layers/batch_invariant.py @yewentao256
/vllm/ir @ProExpertProg
/vllm/kernels/ @ProExpertProg @tjtanaa
/vllm/kernels/helion @ProExpertProg @zou3519
/vllm/multimodal @DarkLight1337 @ywang96 @NickLucche @tjtanaa
/vllm/vllm_flash_attn @LucasWilkinson @MatthewBonanni
CMakeLists.txt @tlrmchlsmth @LucasWilkinson
@@ -74,6 +77,7 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
/tests/entrypoints @DarkLight1337 @robertgshaw2-redhat @aarnphm @NickLucche
/tests/evals @mgoin @vadiklyutiy
/tests/kernels @mgoin @tlrmchlsmth @WoosukKwon @yewentao256
/tests/kernels/ir @ProExpertProg @tjtanaa
/tests/models @DarkLight1337 @ywang96
/tests/multimodal @DarkLight1337 @ywang96 @NickLucche
/tests/quantization @mgoin @robertgshaw2-redhat @yewentao256 @pavanimajety

View File

@@ -28,6 +28,7 @@ jobs:
});
const hasReadyLabel = pr.labels.some(l => l.name === 'ready');
const hasVerifiedLabel = pr.labels.some(l => l.name === 'verified');
const { data: mergedPRs } = await github.rest.search.issuesAndPullRequests({
q: `repo:${context.repo.owner}/${context.repo.repo} is:pr is:merged author:${pr.user.login}`,
@@ -35,10 +36,10 @@ jobs:
});
const mergedCount = mergedPRs.total_count;
if (hasReadyLabel || mergedCount >= 4) {
core.info(`Check passed: ready label=${hasReadyLabel}, 4+ merged PRs=${mergedCount >= 4}`);
if (hasReadyLabel || hasVerifiedLabel || mergedCount >= 4) {
core.info(`Check passed: verified label=${hasVerifiedLabel}, ready label=${hasReadyLabel}, 4+ merged PRs=${mergedCount >= 4}`);
} else {
core.setFailed(`PR must have the 'ready' label or the author must have at least 4 merged PRs (found ${mergedCount}).`);
core.setFailed(`PR must have the 'verified' or 'ready' (which also triggers tests) label or the author must have at least 4 merged PRs (found ${mergedCount}).`);
}
pre-commit:

View File

@@ -39,7 +39,7 @@ repos:
rev: 0.11.1
hooks:
- id: pip-compile
args: [requirements/test.in, -o, requirements/test.txt, --index-strategy, unsafe-best-match, --torch-backend, cu129, --python-platform, x86_64-manylinux_2_28, --python-version, "3.12"]
args: [requirements/test.in, -c, requirements/common.txt, -o, requirements/test.txt, --index-strategy, unsafe-best-match, --torch-backend, cu129, --python-platform, x86_64-manylinux_2_28, --python-version, "3.12"]
files: ^requirements/test\.(in|txt)$
- id: pip-compile
alias: pip-compile-rocm

View File

@@ -340,9 +340,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_EXT_SRC
"csrc/quantization/awq/gemm_kernels.cu"
"csrc/quantization/fp4/nvfp4_quant_entry.cu"
"csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
"csrc/cutlass_extensions/common.cpp")
"csrc/cutlass_extensions/common.cpp"
"csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu")
set_gencode_flags_for_srcs(
SRCS "${VLLM_EXT_SRC}"
@@ -489,59 +488,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
" in CUDA target architectures")
endif()
# The nvfp4_scaled_mm_sm120 kernels for Blackwell SM12x require
# CUDA 12.8 or later
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(FP4_ARCHS "12.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(FP4_ARCHS "12.0a;12.1a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
set(SRCS
"csrc/quantization/fp4/nvfp4_quant_kernels.cu"
"csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
"csrc/quantization/fp4/nvfp4_experts_quant.cu"
"csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu"
"csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${FP4_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM120=1")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM120=1")
message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
else()
message(STATUS "Not building NVFP4 as no compatible archs were found.")
# clear FP4_ARCHS
set(FP4_ARCHS)
endif()
# FP4 Archs and flags
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(FP4_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(FP4_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
set(SRCS
"csrc/quantization/fp4/nvfp4_quant_kernels.cu"
"csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
"csrc/quantization/fp4/nvfp4_experts_quant.cu"
"csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu"
"csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${FP4_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM100=1")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
else()
message(STATUS "Not building NVFP4 as no compatible archs were found.")
# clear FP4_ARCHS
set(FP4_ARCHS)
endif()
# CUTLASS MLA Archs and flags
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(MLA_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
@@ -681,34 +627,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()
endif()
# Only build W4A8 kernels if we are building for something compatible with sm90a
cuda_archs_loose_intersection(W4A8_ARCHS "9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND W4A8_ARCHS)
set(SRCS
"csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu"
"csrc/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu"
"csrc/quantization/cutlass_w4a8/w4a8_utils.cu"
)
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${W4A8_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
message(STATUS "Building W4A8 kernels for archs: ${W4A8_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0
AND W4A8_ARCHS)
message(STATUS "Not building W4A8 kernels as CUDA Compiler version is "
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
"later if you intend on running w4a16 quantized models on "
"Hopper.")
else()
message(STATUS "Not building W4A8 kernels as no compatible archs "
"found in CUDA target architectures")
endif()
endif()
# Hadacore kernels
cuda_archs_loose_intersection(HADACORE_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}")
@@ -760,7 +678,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set(VLLM_STABLE_EXT_SRC
"csrc/libtorch_stable/torch_bindings.cpp"
"csrc/cutlass_extensions/common.cpp"
"csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu")
"csrc/cuda_utils_kernels.cu"
"csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu")
if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_STABLE_EXT_SRC
@@ -978,6 +899,96 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()
endif()
#
# FP4/NVFP4 kernels (moved from _C to _C_stable_libtorch)
#
# The nvfp4_scaled_mm_sm120 kernels for Blackwell SM12x require
# CUDA 12.8 or later
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(FP4_ARCHS "12.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(FP4_ARCHS "12.0a;12.1a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
set(SRCS
"csrc/libtorch_stable/quantization/fp4/nvfp4_quant_kernels.cu"
"csrc/libtorch_stable/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_experts_quant.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_blockwise_moe_kernel.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${FP4_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM120=1")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM120=1")
message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
else()
message(STATUS "Not building NVFP4 as no compatible archs were found.")
# clear FP4_ARCHS
set(FP4_ARCHS)
endif()
# FP4 Archs and flags
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(FP4_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(FP4_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
set(SRCS
"csrc/libtorch_stable/quantization/fp4/nvfp4_quant_kernels.cu"
"csrc/libtorch_stable/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_experts_quant.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_kernels.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_blockwise_moe_kernel.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${FP4_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM100=1")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
else()
message(STATUS "Not building NVFP4 as no compatible archs were found.")
# clear FP4_ARCHS
set(FP4_ARCHS)
endif()
#
# W4A8 kernels (moved from _C to _C_stable_libtorch)
#
# Only build W4A8 kernels if we are building for something compatible with sm90a
cuda_archs_loose_intersection(W4A8_ARCHS "9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND W4A8_ARCHS)
set(SRCS
"csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_mm_entry.cu"
"csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu"
"csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_utils.cu"
)
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${W4A8_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
message(STATUS "Building W4A8 kernels for archs: ${W4A8_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0
AND W4A8_ARCHS)
message(STATUS "Not building W4A8 kernels as CUDA Compiler version is "
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
"later if you intend on running w4a16 quantized models on "
"Hopper.")
else()
message(STATUS "Not building W4A8 kernels as no compatible archs "
"found in CUDA target architectures")
endif()
endif()
message(STATUS "Enabling C_stable extension.")
define_extension_target(
_C_stable_libtorch
@@ -1019,7 +1030,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_MOE_EXT_SRC
"csrc/moe/moe_wna16.cu"
"csrc/moe/grouped_topk_kernels.cu"
"csrc/moe/gpt_oss_router_gemm.cu"
"csrc/moe/router_gemm.cu")
endif()

View File

@@ -0,0 +1,264 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Benchmark: Fused FP8 output quantization in merge_attn_states
Compares fused vs unfused approaches for producing FP8-quantized merged
attention output:
1. Fused CUDA -- single CUDA kernel (merge + FP8 quant)
2. Fused Triton -- single Triton kernel (merge + FP8 quant)
3. Unfused CUDA -- CUDA merge + torch.compiled FP8 quant
4. Unfused Triton -- Triton merge + torch.compiled FP8 quant
Usage:
python benchmarks/fused_kernels/merge_attn_states_benchmarks.py
python benchmarks/fused_kernels/merge_attn_states_benchmarks.py --tp 1 4 8
python benchmarks/fused_kernels/merge_attn_states_benchmarks.py --dtype bfloat16
"""
import argparse
import itertools
import torch
from vllm._custom_ops import merge_attn_states as merge_attn_states_cuda
from vllm.benchmarks.lib.utils import default_vllm_config
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform
from vllm.triton_utils import triton
from vllm.v1.attention.ops.triton_merge_attn_states import (
merge_attn_states as merge_attn_states_triton,
)
# ---------------------------------------------------------------------------
# Configuration defaults
# ---------------------------------------------------------------------------
NUM_TOKENS_LIST = [1, 16, 64, 256, 1024, 4096]
# (label, num_heads, head_size) — num_heads is for TP=1
HEAD_CONFIGS = [
("DeepSeek-V3 MLA", 128, 128),
("Llama-70B", 64, 128),
("Llama-8B", 32, 128),
]
TP_SIZES = [1, 2, 4, 8]
INPUT_DTYPES = [torch.float32, torch.float16, torch.bfloat16]
QUANTILES = [0.5, 0.2, 0.8]
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def short_dtype(dtype: torch.dtype) -> str:
return str(dtype).removeprefix("torch.")
def make_inputs(
num_tokens: int,
num_heads: int,
head_size: int,
dtype: torch.dtype,
):
"""Create random prefix/suffix outputs and LSEs."""
prefix_output = torch.randn(
(num_tokens, num_heads, head_size), dtype=dtype, device="cuda"
)
suffix_output = torch.randn(
(num_tokens, num_heads, head_size), dtype=dtype, device="cuda"
)
prefix_lse = torch.randn(num_heads, num_tokens, dtype=torch.float32, device="cuda")
suffix_lse = torch.randn(num_heads, num_tokens, dtype=torch.float32, device="cuda")
# Sprinkle some inf values to exercise edge-case paths
mask = torch.rand(num_heads, num_tokens, device="cuda") < 0.05
prefix_lse[mask] = float("inf")
mask2 = torch.rand(num_heads, num_tokens, device="cuda") < 0.05
suffix_lse[mask2] = float("inf")
return prefix_output, suffix_output, prefix_lse, suffix_lse
def build_configs(head_configs, num_tokens_list, input_dtypes, tp_sizes):
"""Build (num_tokens, num_heads, head_size, dtype_str) config tuples,
applying TP division to num_heads and skipping invalid combos."""
configs = []
for (_, nh, hs), nt, dtype, tp in itertools.product(
head_configs, num_tokens_list, input_dtypes, tp_sizes
):
nh_tp = nh // tp
if nh_tp >= 1:
configs.append((nt, nh_tp, hs, short_dtype(dtype)))
return configs
def parse_args():
parser = argparse.ArgumentParser(
description="Benchmark merge_attn_states fused FP8 quantization"
)
parser.add_argument(
"--num-tokens",
type=int,
nargs="+",
default=None,
help=f"Override token counts (default: {NUM_TOKENS_LIST})",
)
parser.add_argument(
"--tp",
type=int,
nargs="+",
default=None,
help=f"TP sizes to simulate (divides num_heads) (default: {TP_SIZES})",
)
parser.add_argument(
"--dtype",
type=str,
nargs="+",
default=None,
help="Input dtypes (e.g. bfloat16 float16 float32). "
f"Default: {[short_dtype(d) for d in INPUT_DTYPES]}",
)
return parser.parse_args()
# ---------------------------------------------------------------------------
# Parse args and build configs before decorators
# ---------------------------------------------------------------------------
args = parse_args()
num_tokens_list = args.num_tokens if args.num_tokens else NUM_TOKENS_LIST
tp_sizes = args.tp if args.tp else TP_SIZES
if args.dtype:
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
input_dtypes = [STR_DTYPE_TO_TORCH_DTYPE[d] for d in args.dtype]
else:
input_dtypes = INPUT_DTYPES
configs = build_configs(HEAD_CONFIGS, num_tokens_list, input_dtypes, tp_sizes)
torch._dynamo.config.recompile_limit = 8888
# ---------------------------------------------------------------------------
# Benchmark function
# ---------------------------------------------------------------------------
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["num_tokens", "num_heads", "head_size", "dtype_str"],
x_vals=configs,
line_arg="provider",
line_vals=["fused_cuda", "fused_triton", "unfused_cuda", "unfused_triton"],
line_names=["Fused CUDA", "Fused Triton", "Unfused CUDA", "Unfused Triton"],
styles=[("blue", "-"), ("green", "-"), ("blue", "--"), ("green", "--")],
ylabel="us",
plot_name="merge_attn_states FP8 (fused vs unfused)",
args={},
)
)
@default_vllm_config()
def benchmark(num_tokens, num_heads, head_size, dtype_str, provider):
input_dtype = getattr(torch, dtype_str)
fp8_dtype = current_platform.fp8_dtype()
prefix_out, suffix_out, prefix_lse, suffix_lse = make_inputs(
num_tokens, num_heads, head_size, input_dtype
)
output_scale = torch.tensor([0.1], dtype=torch.float32, device="cuda")
if provider == "fused_cuda":
output = torch.empty(
(num_tokens, num_heads, head_size), dtype=fp8_dtype, device="cuda"
)
fn = lambda: merge_attn_states_cuda(
output,
prefix_out,
prefix_lse,
suffix_out,
suffix_lse,
output_scale=output_scale,
)
elif provider == "fused_triton":
output = torch.empty(
(num_tokens, num_heads, head_size), dtype=fp8_dtype, device="cuda"
)
fn = lambda: merge_attn_states_triton(
output,
prefix_out,
prefix_lse,
suffix_out,
suffix_lse,
output_scale=output_scale,
)
elif provider == "unfused_cuda":
merge_buf = torch.empty(
(num_tokens, num_heads, head_size), dtype=input_dtype, device="cuda"
)
quant_fp8 = QuantFP8(
static=True,
group_shape=GroupShape.PER_TENSOR,
column_major_scales=False,
)
quant_input = merge_buf.view(-1, head_size)
compiled_quant = torch.compile(
quant_fp8.forward_native, fullgraph=True, dynamic=False
)
def unfused_fn():
merge_attn_states_cuda(
merge_buf, prefix_out, prefix_lse, suffix_out, suffix_lse
)
compiled_quant(quant_input, output_scale)
fn = unfused_fn
else: # unfused_triton
merge_buf = torch.empty(
(num_tokens, num_heads, head_size), dtype=input_dtype, device="cuda"
)
quant_fp8 = QuantFP8(
static=True,
group_shape=GroupShape.PER_TENSOR,
column_major_scales=False,
)
quant_input = merge_buf.view(-1, head_size)
compiled_quant = torch.compile(
quant_fp8.forward_native, fullgraph=True, dynamic=False
)
def unfused_fn():
merge_attn_states_triton(
merge_buf, prefix_out, prefix_lse, suffix_out, suffix_lse
)
compiled_quant(quant_input, output_scale)
fn = unfused_fn
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=QUANTILES)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms # us
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
device_name = current_platform.get_device_name()
print(f"Device: {device_name}")
print(f"Token counts: {num_tokens_list}")
print(f"TP sizes: {tp_sizes}")
print(f"Input dtypes: {[short_dtype(d) for d in input_dtypes]}")
print(f"Head configs: {[(c[0], c[1], c[2]) for c in HEAD_CONFIGS]}")
benchmark.run(print_data=True)
if __name__ == "__main__":
with torch.inference_mode():
main()

View File

@@ -0,0 +1,211 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable, Iterable
from dataclasses import dataclass
from itertools import product
import torch
import torch.nn.functional as F
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from tqdm import tqdm
import vllm._custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8,
)
@dataclass
class bench_params_t:
num_tokens: int
hidden_size: int
dtype: torch.dtype
group_size: int # Changed from list[int] to int
def description(self):
return (
f"N {self.num_tokens} "
f"x D {self.hidden_size} "
f"x DT {self.dtype} "
f"x GS {self.group_size}"
)
def get_bench_params() -> list[bench_params_t]:
"""Test configurations covering common model sizes."""
NUM_TOKENS = [16, 128, 512, 2048]
HIDDEN_SIZES = [1024, 2048, 4096, 5120, 14336] # Common FFN sizes
DTYPES = [torch.float16, torch.bfloat16]
GROUP_SIZES = [64, 128] # Changed from [[1, 64], [1, 128]]
combinations = product(NUM_TOKENS, HIDDEN_SIZES, DTYPES, GROUP_SIZES)
bench_params = list(
map(lambda x: bench_params_t(x[0], x[1], x[2], x[3]), combinations)
)
return bench_params
# Reference implementations
def unfused_fp8_impl(
x: torch.Tensor,
quant_dtype: torch.dtype,
group_size: int, # Changed from list[int]
):
"""Unfused: SiLU+Mul then per-tensor quantize."""
hidden = x.shape[-1] // 2
gate, up = x.split(hidden, dim=-1)
# SiLU(gate) * up
silu_out = F.silu(gate) * up
# Per-tensor quantize (no group_size used here)
silu_out, _ = ops.scaled_fp8_quant(silu_out)
def unfused_groupwise_fp8_impl(
x: torch.Tensor,
quant_dtype: torch.dtype,
group_size: int, # Changed from list[int]
):
"""Unfused: SiLU+Mul then group-wise quantize."""
hidden = x.shape[-1] // 2
gate, up = x.split(hidden, dim=-1)
# SiLU(gate) * up
silu_out = F.silu(gate) * up
# Group quantize - use group_size directly
silu_out, _ = per_token_group_quant_fp8(
silu_out, group_size=group_size, use_ue8m0=False
)
def fused_impl(
x: torch.Tensor,
quant_dtype: torch.dtype,
group_size: int,
):
"""Fused: SiLU+Mul+Block Quantization in single kernel."""
out, _ = ops.silu_and_mul_per_block_quant(
x,
group_size=group_size,
quant_dtype=quant_dtype,
is_scale_transposed=False,
)
# Bench functions
def bench_fn(
x: torch.Tensor,
quant_dtype: torch.dtype,
group_size: int,
label: str,
sub_label: str,
fn: Callable,
description: str,
) -> TMeasurement:
min_run_time = 1
globals = {
"x": x,
"quant_dtype": quant_dtype,
"group_size": group_size,
"fn": fn,
}
return TBenchmark.Timer(
stmt="fn(x, quant_dtype, group_size)",
globals=globals,
label=label,
sub_label=sub_label,
description=description,
).blocked_autorange(min_run_time=min_run_time)
def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasurement]:
"""Run benchmarks for all implementations."""
# Make inputs: [num_tokens, hidden_size * 2] for [gate || up]
scale = 1 / params.hidden_size
x = (
torch.randn(
params.num_tokens,
params.hidden_size * 2,
dtype=params.dtype,
device="cuda",
)
* scale
)
timers = []
# Unfused per-tensor FP8
timers.append(
bench_fn(
x,
torch.float8_e4m3fn,
params.group_size,
label,
sub_label,
unfused_fp8_impl,
"unfused_fp8_impl",
)
)
# Unfused group-wise FP8
timers.append(
bench_fn(
x,
torch.float8_e4m3fn,
params.group_size,
label,
sub_label,
unfused_groupwise_fp8_impl,
"unfused_groupwise_fp8_impl",
)
)
# Fused group-wise FP8
timers.append(
bench_fn(
x,
torch.float8_e4m3fn,
params.group_size,
label,
sub_label,
fused_impl,
"fused_groupwise_fp8_impl",
)
)
return timers
def print_timers(timers: Iterable[TMeasurement]):
compare = TBenchmark.Compare(timers)
compare.print()
def main():
torch.set_default_device("cuda")
bench_params = get_bench_params()
print(f"Running {len(bench_params)} benchmark configurations...")
print(
f"This will take approximately {len(bench_params) * 3} seconds (1s per variant)"
)
print()
timers = []
for bp in tqdm(bench_params):
result_timers = bench(bp, "silu-mul-block-quant", bp.description())
timers.extend(result_timers)
print("\n" + "=" * 80)
print("FINAL COMPARISON - ALL RESULTS")
print("=" * 80)
print_timers(timers)
if __name__ == "__main__":
main()

View File

@@ -1,134 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.nn.functional as F
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_config
from vllm.triton_utils import triton
from vllm.utils.argparse_utils import FlexibleArgumentParser
# Dimensions supported by the DSV3 specialized kernel
DSV3_SUPPORTED_NUM_EXPERTS = [256, 384]
DSV3_SUPPORTED_HIDDEN_SIZES = [7168]
# Dimensions supported by the gpt-oss specialized kernel
GPT_OSS_SUPPORTED_NUM_EXPERTS = [32, 128]
GPT_OSS_SUPPORTED_HIDDEN_SIZES = [2880]
def get_batch_size_range(max_batch_size):
return [2**x for x in range(14) if 2**x <= max_batch_size]
def get_model_params(config):
if config.architectures[0] in (
"DeepseekV2ForCausalLM",
"DeepseekV3ForCausalLM",
"DeepseekV32ForCausalLM",
):
num_experts = config.n_routed_experts
hidden_size = config.hidden_size
elif config.architectures[0] in ("GptOssForCausalLM",):
num_experts = config.num_local_experts
hidden_size = config.hidden_size
else:
raise ValueError(f"Unsupported architecture: {config.architectures}")
return num_experts, hidden_size
def get_benchmark(model, max_batch_size, trust_remote_code):
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=get_batch_size_range(max_batch_size),
x_log=False,
line_arg="provider",
line_vals=[
"torch",
"vllm",
],
line_names=["PyTorch", "vLLM"],
styles=([("blue", "-"), ("red", "-")]),
ylabel="TFLOPs",
plot_name=f"{model} router gemm throughput",
args={},
)
)
def benchmark(batch_size, provider):
config = get_config(model=model, trust_remote_code=trust_remote_code)
num_experts, hidden_size = get_model_params(config)
mat_a = torch.randn(
(batch_size, hidden_size), dtype=torch.bfloat16, device="cuda"
).contiguous()
mat_b = torch.randn(
(num_experts, hidden_size), dtype=torch.bfloat16, device="cuda"
).contiguous()
bias = torch.randn(
num_experts, dtype=torch.bfloat16, device="cuda"
).contiguous()
is_hopper_or_blackwell = current_platform.is_device_capability(
90
) or current_platform.is_device_capability_family(100)
allow_dsv3_router_gemm = (
is_hopper_or_blackwell
and num_experts in DSV3_SUPPORTED_NUM_EXPERTS
and hidden_size in DSV3_SUPPORTED_HIDDEN_SIZES
)
allow_gpt_oss_router_gemm = (
is_hopper_or_blackwell
and num_experts in GPT_OSS_SUPPORTED_NUM_EXPERTS
and hidden_size in GPT_OSS_SUPPORTED_HIDDEN_SIZES
)
has_bias = False
if allow_gpt_oss_router_gemm:
has_bias = True
quantiles = [0.5, 0.2, 0.8]
if provider == "torch":
def runner():
if has_bias:
F.linear(mat_a, mat_b, bias)
else:
F.linear(mat_a, mat_b)
elif provider == "vllm":
def runner():
if allow_dsv3_router_gemm:
ops.dsv3_router_gemm(mat_a, mat_b, torch.bfloat16)
elif allow_gpt_oss_router_gemm:
ops.gpt_oss_router_gemm(mat_a, mat_b, bias)
else:
raise ValueError("Unsupported router gemm")
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
runner, quantiles=quantiles
)
def tflops(t_ms):
flops = 2 * batch_size * hidden_size * num_experts
return flops / (t_ms * 1e-3) / 1e12
return tflops(ms), tflops(max_ms), tflops(min_ms)
return benchmark
if __name__ == "__main__":
parser = FlexibleArgumentParser()
parser.add_argument("--model", type=str, default="openai/gpt-oss-20b")
parser.add_argument("--max-batch-size", default=16, type=int)
parser.add_argument("--trust-remote-code", action="store_true")
args = parser.parse_args()
# Get the benchmark function
benchmark = get_benchmark(args.model, args.max_batch_size, args.trust_remote_code)
# Run performance benchmark
benchmark.run(print_data=True)

View File

@@ -0,0 +1,162 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Benchmarks the fused Triton bilinear position-embedding kernel against
# the pure-PyTorch (native) implementation used in Qwen3-VL ViT models.
#
# == Usage Examples ==
#
# Default benchmark:
# python3 benchmark_vit_bilinear_pos_embed.py
#
# Custom parameters:
# python3 benchmark_vit_bilinear_pos_embed.py --hidden-dim 1152 \
# --num-grid-per-side 48 --save-path ./configs/vit_pos_embed/
import itertools
import torch
from vllm.model_executor.models.qwen3_vl import (
pos_embed_interpolate_native,
triton_pos_embed_interpolate,
)
from vllm.triton_utils import HAS_TRITON, triton
from vllm.utils.argparse_utils import FlexibleArgumentParser
# (h, w) configurations to benchmark
h_w_configs = [
(16, 16),
(32, 32),
(48, 48),
(64, 64),
(128, 128),
(32, 48),
(60, 80),
]
# Temporal dimensions
t_range = [1]
configs = list(itertools.product(t_range, h_w_configs))
def get_benchmark(
num_grid_per_side: int,
spatial_merge_size: int,
hidden_dim: int,
dtype: torch.dtype,
device: str,
):
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["t", "h_w"],
x_vals=[list(_) for _ in configs],
line_arg="provider",
line_vals=["native", "triton"],
line_names=["Native (PyTorch)", "Triton"],
styles=[("blue", "-"), ("red", "-")],
ylabel="us",
plot_name=(
f"vit-bilinear-pos-embed-"
f"grid{num_grid_per_side}-"
f"dim{hidden_dim}-"
f"{dtype}"
),
args={},
)
)
def benchmark(t, h_w, provider):
h, w = h_w
torch.manual_seed(42)
embed_weight = (
torch.randn(
num_grid_per_side * num_grid_per_side,
hidden_dim,
device=device,
dtype=dtype,
)
* 0.25
)
quantiles = [0.5, 0.2, 0.8]
if provider == "native":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: pos_embed_interpolate_native(
embed_weight,
t,
h,
w,
num_grid_per_side,
spatial_merge_size,
dtype,
),
quantiles=quantiles,
)
else:
assert HAS_TRITON, "Triton not available"
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: triton_pos_embed_interpolate(
embed_weight,
t,
h,
w,
num_grid_per_side,
spatial_merge_size,
dtype,
),
quantiles=quantiles,
)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
return benchmark
if __name__ == "__main__":
parser = FlexibleArgumentParser(
description="Benchmark bilinear position embedding interpolation."
)
parser.add_argument(
"--num-grid-per-side",
type=int,
default=48,
help="Position embedding grid size (default: 48 for Qwen3-VL)",
)
parser.add_argument(
"--spatial-merge-size",
type=int,
default=2,
help="Spatial merge size (default: 2)",
)
parser.add_argument(
"--hidden-dim",
type=int,
default=1152,
help="Embedding hidden dimension (default: 1152 for Qwen3-VL)",
)
parser.add_argument(
"--device",
type=str,
choices=["cuda:0", "cuda:1"],
default="cuda:0",
)
parser.add_argument(
"--save-path",
type=str,
default="./vit_pos_embed/",
)
args = parser.parse_args()
dtype = torch.bfloat16
bench = get_benchmark(
args.num_grid_per_side,
args.spatial_merge_size,
args.hidden_dim,
dtype,
args.device,
)
bench.run(print_data=True, save_path=args.save_path)

View File

@@ -39,7 +39,7 @@ else()
FetchContent_Declare(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG 29210221863736a08f71a866459e368ad1ac4a95
GIT_TAG c0ec424fd8a546d0cbbf4bf050bbcfe837c55afb
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn

View File

@@ -3,22 +3,33 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <algorithm>
#include <limits>
#include "attention_dtypes.h"
#include "attention_utils.cuh"
#include "../quantization/w8a8/fp8/common.cuh"
#include "../dispatch_utils.h"
namespace vllm {
// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
// can be used to combine partial attention results (in the split-KV case)
template <typename scalar_t, const uint NUM_THREADS>
template <typename scalar_t, typename output_t, const uint NUM_THREADS,
bool USE_FP8_OUTPUT>
__global__ void merge_attn_states_kernel(
scalar_t* output, float* output_lse, const scalar_t* prefix_output,
output_t* output, float* output_lse, const scalar_t* prefix_output,
const float* prefix_lse, const scalar_t* suffix_output,
const float* suffix_lse, const uint num_tokens, const uint num_heads,
const uint head_size, const uint prefix_head_stride,
const uint output_head_stride) {
using pack_128b_t = uint4;
const uint output_head_stride, const uint prefix_num_tokens,
const float* output_scale) {
// Inputs always load 128-bit packs (pack_size elements of scalar_t).
// Outputs store pack_size elements of output_t, which is smaller for FP8.
using input_pack_t = uint4;
using output_pack_t =
std::conditional_t<USE_FP8_OUTPUT,
std::conditional_t<sizeof(scalar_t) == 4, uint, uint2>,
uint4>;
const uint pack_size = 16 / sizeof(scalar_t);
const uint threads_per_head = head_size / pack_size;
@@ -41,8 +52,45 @@ __global__ void merge_attn_states_kernel(
head_idx * output_head_stride;
const scalar_t* prefix_head_ptr = prefix_output + src_head_offset;
const scalar_t* suffix_head_ptr = suffix_output + src_head_offset;
scalar_t* output_head_ptr = output + dst_head_offset;
output_t* output_head_ptr = output + dst_head_offset;
// Pre-invert scale: multiplication is faster than division
float fp8_scale_inv = 1.0f;
if constexpr (USE_FP8_OUTPUT) {
fp8_scale_inv = 1.0f / *output_scale;
}
// If token_idx >= prefix_num_tokens, just copy from suffix
if (token_idx >= prefix_num_tokens) {
if (pack_offset < head_size) {
input_pack_t s_out_pack = reinterpret_cast<const input_pack_t*>(
suffix_head_ptr)[pack_offset / pack_size];
if constexpr (USE_FP8_OUTPUT) {
output_t o_out_pack[pack_size];
#pragma unroll
for (uint i = 0; i < pack_size; ++i) {
const float val =
vllm::to_float(reinterpret_cast<const scalar_t*>(&s_out_pack)[i]);
o_out_pack[i] =
vllm::scaled_fp8_conversion<true, output_t>(val, fp8_scale_inv);
}
reinterpret_cast<output_pack_t*>(
output_head_ptr)[pack_offset / pack_size] =
*reinterpret_cast<output_pack_t*>(o_out_pack);
} else {
reinterpret_cast<output_pack_t*>(
output_head_ptr)[pack_offset / pack_size] = s_out_pack;
}
}
if (output_lse != nullptr && pack_idx == 0) {
float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
output_lse[head_idx * num_tokens + token_idx] = s_lse;
}
return;
}
// For tokens within prefix range, merge prefix and suffix
float p_lse = prefix_lse[head_idx * num_tokens + token_idx];
float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
p_lse = std::isinf(p_lse) ? -std::numeric_limits<float>::infinity() : p_lse;
@@ -53,20 +101,34 @@ __global__ void merge_attn_states_kernel(
/* In certain edge cases, MLA can produce p_lse = s_lse = -inf;
continuing the pipeline then yields NaN. Root cause: with chunked prefill
a batch may be split into two chunks; if a request in that batch has no
prefix hit, every LSE entry for that requests position is -inf, and at
prefix hit, every LSE entry for that request's position is -inf, and at
this moment we merge cross-attention at first. For now we simply emit
prefix_output (expected to be all zeros) and prefix_lse (-inf) to fix
this problem.
*/
if (std::isinf(max_lse)) {
if (pack_offset < head_size) {
// Pack 128b load
pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(
input_pack_t p_out_pack = reinterpret_cast<const input_pack_t*>(
prefix_head_ptr)[pack_offset / pack_size];
// Pack 128b storage
reinterpret_cast<pack_128b_t*>(output_head_ptr)[pack_offset / pack_size] =
p_out_pack;
if constexpr (USE_FP8_OUTPUT) {
// Convert prefix values to FP8 (since -inf means no data,
// prefix_output is expected to be zeros)
output_t o_out_pack[pack_size];
#pragma unroll
for (uint i = 0; i < pack_size; ++i) {
const float val =
vllm::to_float(reinterpret_cast<const scalar_t*>(&p_out_pack)[i]);
o_out_pack[i] =
vllm::scaled_fp8_conversion<true, output_t>(val, fp8_scale_inv);
}
reinterpret_cast<output_pack_t*>(
output_head_ptr)[pack_offset / pack_size] =
*reinterpret_cast<output_pack_t*>(o_out_pack);
} else {
reinterpret_cast<output_pack_t*>(
output_head_ptr)[pack_offset / pack_size] = p_out_pack;
}
}
// We only need to write to output_lse once per head.
if (output_lse != nullptr && pack_idx == 0) {
@@ -84,30 +146,43 @@ __global__ void merge_attn_states_kernel(
const float s_scale = s_se / out_se;
if (pack_offset < head_size) {
// Pack 128b load
pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(
input_pack_t p_out_pack = reinterpret_cast<const input_pack_t*>(
prefix_head_ptr)[pack_offset / pack_size];
pack_128b_t s_out_pack = reinterpret_cast<const pack_128b_t*>(
input_pack_t s_out_pack = reinterpret_cast<const input_pack_t*>(
suffix_head_ptr)[pack_offset / pack_size];
pack_128b_t o_out_pack;
// Compute merged values in float32
float o_out_f[pack_size];
#pragma unroll
for (uint i = 0; i < pack_size; ++i) {
// Always use float for FMA to keep high precision.
// half(uint16_t), bfloat16, float -> float.
const float p_out_f =
vllm::to_float(reinterpret_cast<const scalar_t*>(&p_out_pack)[i]);
const float s_out_f =
vllm::to_float(reinterpret_cast<const scalar_t*>(&s_out_pack)[i]);
// fma: a * b + c = p_out_f * p_scale + (s_out_f * s_scale)
const float o_out_f = p_out_f * p_scale + (s_out_f * s_scale);
// float -> half(uint16_t), bfloat16, float.
vllm::from_float(reinterpret_cast<scalar_t*>(&o_out_pack)[i], o_out_f);
o_out_f[i] = p_out_f * p_scale + (s_out_f * s_scale);
}
// Pack 128b storage
reinterpret_cast<pack_128b_t*>(output_head_ptr)[pack_offset / pack_size] =
o_out_pack;
// Convert and store
if constexpr (USE_FP8_OUTPUT) {
output_t o_out_pack[pack_size];
#pragma unroll
for (uint i = 0; i < pack_size; ++i) {
o_out_pack[i] = vllm::scaled_fp8_conversion<true, output_t>(
o_out_f[i], fp8_scale_inv);
}
reinterpret_cast<output_pack_t*>(
output_head_ptr)[pack_offset / pack_size] =
*reinterpret_cast<output_pack_t*>(o_out_pack);
} else {
output_pack_t o_out_pack;
#pragma unroll
for (uint i = 0; i < pack_size; ++i) {
vllm::from_float(reinterpret_cast<scalar_t*>(&o_out_pack)[i],
o_out_f[i]);
}
reinterpret_cast<output_pack_t*>(
output_head_ptr)[pack_offset / pack_size] = o_out_pack;
}
}
// We only need to write to output_lse once per head.
if (output_lse != nullptr && pack_idx == 0) {
@@ -134,50 +209,73 @@ __global__ void merge_attn_states_kernel(
} \
}
#define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS) \
#define LAUNCH_MERGE_ATTN_STATES(scalar_t, output_t, NUM_THREADS, \
USE_FP8_OUTPUT) \
{ \
vllm::merge_attn_states_kernel<scalar_t, NUM_THREADS> \
vllm::merge_attn_states_kernel<scalar_t, output_t, NUM_THREADS, \
USE_FP8_OUTPUT> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<scalar_t*>(output.data_ptr()), output_lse_ptr, \
reinterpret_cast<output_t*>(output.data_ptr()), output_lse_ptr, \
reinterpret_cast<scalar_t*>(prefix_output.data_ptr()), \
reinterpret_cast<float*>(prefix_lse.data_ptr()), \
reinterpret_cast<scalar_t*>(suffix_output.data_ptr()), \
reinterpret_cast<float*>(suffix_lse.data_ptr()), num_tokens, \
num_heads, head_size, prefix_head_stride, output_head_stride); \
num_heads, head_size, prefix_head_stride, output_head_stride, \
prefix_num_tokens, output_scale_ptr); \
}
/*@brief Merges the attention states from prefix and suffix
* into the output tensor. NUM_TOKENS: n, NUM_HEADS: h, HEAD_SIZE: d
*
* @param output [n,h,d] The output tensor to store the merged attention states.
* @param output_lse [h,d] Optional tensor to store the log-sum-exp values.
* @param output_lse [h,n] Optional tensor to store the log-sum-exp values.
* @param prefix_output [n,h,d] The prefix attention states.
* @param prefix_lse [h,n] The log-sum-exp values for the prefix attention
* states.
* @param suffix_output [n,h,d] The suffix attention states.
* @param suffix_lse [h,n] The log-sum-exp values for the suffix attention
* states.
* @param prefill_tokens_with_context Number of prefill tokens with context
* For the first p tokens (0 <= token_idx < prefill_tokens_with_context), output
* is computed by merging prefix_output and suffix_output. For remaining tokens
* (prefill_tokens_with_context <= token_idx < n), output is copied directly
* from suffix_output.
* @param output_scale Optional scalar tensor for FP8 static quantization.
* When provided, output must be FP8 dtype.
*/
template <typename scalar_t>
void merge_attn_states_launcher(torch::Tensor& output,
std::optional<torch::Tensor> output_lse,
const torch::Tensor& prefix_output,
const torch::Tensor& prefix_lse,
const torch::Tensor& suffix_output,
const torch::Tensor& suffix_lse) {
void merge_attn_states_launcher(
torch::Tensor& output, std::optional<torch::Tensor> output_lse,
const torch::Tensor& prefix_output, const torch::Tensor& prefix_lse,
const torch::Tensor& suffix_output, const torch::Tensor& suffix_lse,
const std::optional<int64_t> prefill_tokens_with_context,
const std::optional<torch::Tensor>& output_scale) {
constexpr uint NUM_THREADS = 128;
const uint num_tokens = output.size(0);
const uint num_heads = output.size(1);
const uint head_size = output.size(2);
const uint prefix_head_stride = prefix_output.stride(1);
const uint output_head_stride = output.stride(1);
// Thread mapping is based on input BF16 pack_size
const uint pack_size = 16 / sizeof(scalar_t);
TORCH_CHECK(head_size % pack_size == 0,
"headsize must be multiple of pack_size:", pack_size);
const uint prefix_num_tokens =
prefill_tokens_with_context.has_value()
? static_cast<uint>(prefill_tokens_with_context.value())
: num_tokens;
TORCH_CHECK(prefix_num_tokens <= num_tokens,
"prefix_num_tokens must be <= num_tokens");
float* output_lse_ptr = nullptr;
if (output_lse.has_value()) {
output_lse_ptr = output_lse.value().data_ptr<float>();
}
float* output_scale_ptr = nullptr;
if (output_scale.has_value()) {
output_scale_ptr = output_scale.value().data_ptr<float>();
}
// Process one pack elements per thread. for float, the
// pack_size is 4 for half/bf16, the pack_size is 8.
const uint threads_per_head = head_size / pack_size;
@@ -189,14 +287,22 @@ void merge_attn_states_launcher(torch::Tensor& output,
const c10::cuda::OptionalCUDAGuard device_guard(prefix_output.device());
auto stream = at::cuda::getCurrentCUDAStream();
LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS);
if (output_scale.has_value()) {
// FP8 output path - dispatch on output FP8 type
VLLM_DISPATCH_FP8_TYPES(output.scalar_type(), "merge_attn_states_fp8", [&] {
LAUNCH_MERGE_ATTN_STATES(scalar_t, fp8_t, NUM_THREADS, true);
});
} else {
// Original BF16/FP16/FP32 output path
LAUNCH_MERGE_ATTN_STATES(scalar_t, scalar_t, NUM_THREADS, false);
}
}
#define CALL_MERGE_ATTN_STATES_LAUNCHER(scalar_t) \
{ \
merge_attn_states_launcher<scalar_t>(output, output_lse, prefix_output, \
prefix_lse, suffix_output, \
suffix_lse); \
#define CALL_MERGE_ATTN_STATES_LAUNCHER(scalar_t) \
{ \
merge_attn_states_launcher<scalar_t>( \
output, output_lse, prefix_output, prefix_lse, suffix_output, \
suffix_lse, prefill_tokens_with_context, output_scale); \
}
void merge_attn_states(torch::Tensor& output,
@@ -204,6 +310,21 @@ void merge_attn_states(torch::Tensor& output,
const torch::Tensor& prefix_output,
const torch::Tensor& prefix_lse,
const torch::Tensor& suffix_output,
const torch::Tensor& suffix_lse) {
DISPATCH_BY_SCALAR_DTYPE(output.dtype(), CALL_MERGE_ATTN_STATES_LAUNCHER);
const torch::Tensor& suffix_lse,
std::optional<int64_t> prefill_tokens_with_context,
const std::optional<torch::Tensor>& output_scale) {
if (output_scale.has_value()) {
TORCH_CHECK(output.scalar_type() == at::ScalarType::Float8_e4m3fn ||
output.scalar_type() == at::ScalarType::Float8_e4m3fnuz,
"output must be FP8 when output_scale is provided, got: ",
output.scalar_type());
} else {
TORCH_CHECK(output.scalar_type() == prefix_output.scalar_type(),
"output dtype (", output.scalar_type(),
") must match prefix_output dtype (",
prefix_output.scalar_type(), ") when output_scale is not set");
}
// Always dispatch on prefix_output (input) dtype
DISPATCH_BY_SCALAR_DTYPE(prefix_output.dtype(),
CALL_MERGE_ATTN_STATES_LAUNCHER);
}

View File

@@ -10,6 +10,10 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
int64_t block_size_in_bytes,
const torch::Tensor& block_mapping);
void swap_blocks_batch(const torch::Tensor& src_ptrs,
const torch::Tensor& dst_ptrs,
const torch::Tensor& sizes);
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& slot_mapping,

View File

@@ -24,6 +24,8 @@
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
typedef __hip_bfloat16 __nv_bfloat16;
#else
#include <cuda.h>
#endif
#if defined(__gfx942__)
@@ -73,6 +75,59 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
}
}
void swap_blocks_batch(const torch::Tensor& src_ptrs,
const torch::Tensor& dst_ptrs,
const torch::Tensor& sizes) {
TORCH_CHECK(src_ptrs.device().is_cpu(), "src_ptrs must be on CPU");
TORCH_CHECK(dst_ptrs.device().is_cpu(), "dst_ptrs must be on CPU");
TORCH_CHECK(sizes.device().is_cpu(), "sizes must be on CPU");
TORCH_CHECK(src_ptrs.dtype() == torch::kInt64, "src_ptrs must be int64");
TORCH_CHECK(dst_ptrs.dtype() == torch::kInt64, "dst_ptrs must be int64");
TORCH_CHECK(sizes.dtype() == torch::kInt64, "sizes must be int64");
const int64_t n = src_ptrs.size(0);
TORCH_CHECK(dst_ptrs.size(0) == n, "dst_ptrs length must match src_ptrs");
TORCH_CHECK(sizes.size(0) == n, "sizes length must match src_ptrs");
if (n == 0) return;
const int64_t* src_data = src_ptrs.data_ptr<int64_t>();
const int64_t* dst_data = dst_ptrs.data_ptr<int64_t>();
const int64_t* size_data = sizes.data_ptr<int64_t>();
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// Use cuMemcpyBatchAsync (CUDA 12.8+) to submit all copies in a single
// driver call, amortizing per-copy submission overhead.
// int64_t and CUdeviceptr/size_t are both 8 bytes on 64-bit platforms,
// so we reinterpret_cast the tensor data directly to avoid copies.
static_assert(sizeof(CUdeviceptr) == sizeof(int64_t));
static_assert(sizeof(size_t) == sizeof(int64_t));
#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12080
CUmemcpyAttributes attr = {};
attr.srcAccessOrder = CU_MEMCPY_SRC_ACCESS_ORDER_STREAM;
size_t attrs_idx = 0;
size_t fail_idx = 0;
CUresult result = cuMemcpyBatchAsync(
reinterpret_cast<CUdeviceptr*>(const_cast<int64_t*>(dst_data)),
reinterpret_cast<CUdeviceptr*>(const_cast<int64_t*>(src_data)),
reinterpret_cast<size_t*>(const_cast<int64_t*>(size_data)),
static_cast<size_t>(n), &attr, &attrs_idx, 1, &fail_idx,
static_cast<CUstream>(stream));
TORCH_CHECK(result == CUDA_SUCCESS, "cuMemcpyBatchAsync failed at index ",
fail_idx, " with error ", result);
#else
// Fallback for CUDA < 12.8 and ROCm: individual async copies.
// cudaMemcpyDefault lets the driver infer direction from pointer types.
for (int64_t i = 0; i < n; i++) {
cudaMemcpyAsync(reinterpret_cast<void*>(dst_data[i]),
reinterpret_cast<void*>(src_data[i]),
static_cast<size_t>(size_data[i]), cudaMemcpyDefault,
stream);
}
#endif
}
namespace vllm {
// Grid: (num_layers, num_pairs)

View File

@@ -30,13 +30,15 @@
}()
namespace {
enum class FusedMOEAct { SiluAndMul, SwigluOAIAndMul };
enum class FusedMOEAct { SiluAndMul, SwigluOAIAndMul, GeluAndMul };
FusedMOEAct get_act_type(const std::string& act) {
if (act == "silu") {
return FusedMOEAct::SiluAndMul;
} else if (act == "swigluoai") {
return FusedMOEAct::SwigluOAIAndMul;
} else if (act == "gelu") {
return FusedMOEAct::GeluAndMul;
} else {
TORCH_CHECK(false, "Invalid act type: " + act);
}
@@ -104,6 +106,43 @@ void silu_and_mul(float* __restrict__ input, scalar_t* __restrict__ output,
}
}
template <typename scalar_t>
void gelu_and_mul(float* __restrict__ input, scalar_t* __restrict__ output,
const int32_t m_size, const int32_t n_size,
const int32_t input_stride, const int32_t output_stride) {
using scalar_vec_t = typename cpu_utils::VecTypeTrait<scalar_t>::vec_t;
const int32_t dim = n_size / 2;
float* __restrict__ gate = input;
float* __restrict__ up = input + dim;
vec_op::FP32Vec16 one_vec(1.0);
vec_op::FP32Vec16 w1_vec(M_SQRT1_2);
vec_op::FP32Vec16 w2_vec(0.5);
alignas(64) float temp[16];
DEFINE_FAST_EXP
for (int32_t m = 0; m < m_size; ++m) {
for (int32_t n = 0; n < dim; n += 16) {
vec_op::FP32Vec16 gate_vec(gate + n);
vec_op::FP32Vec16 up_vec(up + n);
auto er_input_vec = gate_vec * w1_vec;
er_input_vec.save(temp);
for (int32_t i = 0; i < 16; ++i) {
temp[i] = std::erf(temp[i]);
}
vec_op::FP32Vec16 er_vec(temp);
auto gelu = gate_vec * w2_vec * (one_vec + er_vec);
auto gated_output_fp32 = up_vec * gelu;
scalar_vec_t gated_output = scalar_vec_t(gated_output_fp32);
gated_output.save(output + n);
}
gate += input_stride;
up += input_stride;
output += output_stride;
}
}
template <typename scalar_t>
FORCE_INLINE void apply_gated_act(const FusedMOEAct act,
float* __restrict__ input,
@@ -118,6 +157,9 @@ FORCE_INLINE void apply_gated_act(const FusedMOEAct act,
case FusedMOEAct::SiluAndMul:
silu_and_mul(input, output, m, n, input_stride, output_stride);
return;
case FusedMOEAct::GeluAndMul:
gelu_and_mul(input, output, m, n, input_stride, output_stride);
return;
default:
TORCH_CHECK(false, "Unsupported act type.");
}

View File

@@ -8,7 +8,7 @@ Generate CPU attention dispatch switch cases and kernel instantiations.
import os
# Head dimensions divisible by 32 (support all ISAs)
HEAD_DIMS_32 = [32, 64, 96, 128, 160, 192, 224, 256]
HEAD_DIMS_32 = [32, 64, 96, 128, 160, 192, 224, 256, 512]
# Head dimensions divisible by 16 but not 32 (VEC16 only)
HEAD_DIMS_16 = [80, 112]

View File

@@ -3,8 +3,8 @@
#pragma once
#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
#include <torch/headeronly/util/BFloat16.h>
#include <torch/headeronly/util/Half.h>
#include <cassert>
#ifdef USE_ROCM

View File

@@ -1,7 +1,6 @@
#pragma once
#include <cute/tensor.hpp>
#include <torch/all.h>
namespace cute {
////////////////////////////////////////////////////////////////////

View File

@@ -189,9 +189,9 @@ struct Sm90RowOrScalarBroadcastArray {
}
auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
cute::Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
cute::Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
cute::Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
for (int i = 0; i < size(tGS_gRow_flt); ++i) {
if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
@@ -211,8 +211,8 @@ struct Sm90RowOrScalarBroadcastArray {
begin_loop(int epi_m, int epi_n) {
if (epi_m == 0) { // Assumes M-major subtile loop
if (!params.row_broadcast) return; // Do not issue LDS when row is scalar
Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
cute::Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
cute::Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
copy(tSR_sRow_flt, tSR_rRow_flt);
}
}
@@ -241,9 +241,9 @@ struct Sm90RowOrScalarBroadcastArray {
auto [m, n, k, l] = args.tile_coord_mnkl;
using ThreadCount = decltype(size(args.tiled_copy));
Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row_array[l]), make_shape(M,N,1), params.dRow);
Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
Tensor sRow = make_tensor(make_smem_ptr(smem),
cute::Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row_array[l]), make_shape(M,N,1), params.dRow);
cute::Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
cute::Tensor sRow = make_tensor(make_smem_ptr(smem),
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N)
//// G2S: Gmem to Smem
auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
@@ -251,16 +251,16 @@ struct Sm90RowOrScalarBroadcastArray {
Stride<_0, _1>>{},
Layout<_1>{});
auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
Tensor tGS_gRow = thr_g2s.partition_S(gRow);
Tensor tGS_sRow = thr_g2s.partition_D(sRow);
cute::Tensor tGS_gRow = thr_g2s.partition_S(gRow);
cute::Tensor tGS_sRow = thr_g2s.partition_D(sRow);
//// G2S: Coord
auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
Tensor tGS_cRow = thr_g2s.partition_S(cRow);
cute::Tensor tGS_cRow = thr_g2s.partition_S(cRow);
//// S2R: Smem to Reg
Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
cute::Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
cute::Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
return ConsumerStoreCallbacks<decltype(tGS_gRow), decltype(tGS_sRow), decltype(tGS_cRow), decltype(tiled_g2s), decltype(tSR_sRow), decltype(tSR_rRow), decltype(args.tCcD), decltype(args.residue_cD), ThreadCount>(
tGS_gRow,
@@ -389,7 +389,7 @@ struct Sm90ColOrScalarBroadcastArray {
CUTLASS_DEVICE void
begin() {
Tensor pred = make_tensor<bool>(shape(tCgCol));
cute::Tensor pred = make_tensor<bool>(shape(tCgCol));
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(pred); ++i) {
pred(i) = get<0>(tCcCol(i)) < m;
@@ -409,7 +409,7 @@ struct Sm90ColOrScalarBroadcastArray {
CUTLASS_DEVICE Array<Element, FragmentSize>
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
Array<Element, FragmentSize> frg_col;
Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
cute::Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < FragmentSize; ++i) {
@@ -431,16 +431,16 @@ struct Sm90ColOrScalarBroadcastArray {
auto [M, N, K, L] = args.problem_shape_mnkl;
auto [m, n, k, l] = args.tile_coord_mnkl;
Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col_array[l]), make_shape(M,N,1), params.dCol);
Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
cute::Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col_array[l]), make_shape(M,N,1), params.dCol);
cute::Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
cute::Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
// Generate an identity tensor matching the shape of the global tensor and
// partition the same way, this will be used to generate the predicate
// tensor for loading
Tensor cCol = make_identity_tensor(mCol.shape());
Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
cute::Tensor cCol = make_identity_tensor(mCol.shape());
cute::Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
return ConsumerStoreCallbacks(

View File

@@ -186,9 +186,9 @@ struct Sm90RowOrScalarBroadcast {
}
auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
cute::Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
cute::Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
cute::Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
for (int i = 0; i < size(tGS_gRow_flt); ++i) {
if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
@@ -208,8 +208,8 @@ struct Sm90RowOrScalarBroadcast {
begin_loop(int epi_m, int epi_n) {
if (epi_m == 0) { // Assumes M-major subtile loop
if (!params.row_broadcast) return; // Do not issue LDS when row is scalar
Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
cute::Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
cute::Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
copy(tSR_sRow_flt, tSR_rRow_flt);
}
}
@@ -238,9 +238,9 @@ struct Sm90RowOrScalarBroadcast {
auto [m, n, k, l] = args.tile_coord_mnkl;
using ThreadCount = decltype(size(args.tiled_copy));
Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
Tensor sRow = make_tensor(make_smem_ptr(smem),
cute::Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
cute::Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
cute::Tensor sRow = make_tensor(make_smem_ptr(smem),
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N)
//// G2S: Gmem to Smem
auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
@@ -248,16 +248,16 @@ struct Sm90RowOrScalarBroadcast {
Stride<_0, _1>>{},
Layout<_1>{});
auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
Tensor tGS_gRow = thr_g2s.partition_S(gRow);
Tensor tGS_sRow = thr_g2s.partition_D(sRow);
cute::Tensor tGS_gRow = thr_g2s.partition_S(gRow);
cute::Tensor tGS_sRow = thr_g2s.partition_D(sRow);
//// G2S: Coord
auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
Tensor tGS_cRow = thr_g2s.partition_S(cRow);
cute::Tensor tGS_cRow = thr_g2s.partition_S(cRow);
//// S2R: Smem to Reg
Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
cute::Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
cute::Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
return ConsumerStoreCallbacks<decltype(tGS_gRow), decltype(tGS_sRow), decltype(tGS_cRow), decltype(tiled_g2s), decltype(tSR_sRow), decltype(tSR_rRow), decltype(args.tCcD), decltype(args.residue_cD), ThreadCount>(
tGS_gRow,
@@ -382,7 +382,7 @@ struct Sm90ColOrScalarBroadcast {
CUTLASS_DEVICE void
begin() {
Tensor pred = make_tensor<bool>(shape(tCgCol));
cute::Tensor pred = make_tensor<bool>(shape(tCgCol));
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(pred); ++i) {
pred(i) = get<0>(tCcCol(i)) < m;
@@ -402,7 +402,7 @@ struct Sm90ColOrScalarBroadcast {
CUTLASS_DEVICE Array<Element, FragmentSize>
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
Array<Element, FragmentSize> frg_col;
Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
cute::Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < FragmentSize; ++i) {
@@ -422,16 +422,16 @@ struct Sm90ColOrScalarBroadcast {
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
auto [M, N, K, L] = args.problem_shape_mnkl;
Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol);
Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
cute::Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol);
cute::Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
cute::Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
// Generate an identity tensor matching the shape of the global tensor and
// partition the same way, this will be used to generate the predicate
// tensor for loading
Tensor cCol = make_identity_tensor(mCol.shape());
Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
cute::Tensor cCol = make_identity_tensor(mCol.shape());
cute::Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
return ConsumerStoreCallbacks(

View File

@@ -1,6 +1,21 @@
#pragma once
#include <torch/all.h>
// This header is shared between _C (unstable ABI, used by machete) and
// _C_stable_libtorch (stable ABI, used by W4A8/sparse). TORCH_TARGET_VERSION
// is defined only for the stable target, so we switch includes and types
// accordingly. TorchTensor (not Tensor) avoids ambiguity with cute::Tensor.
#ifdef TORCH_TARGET_VERSION
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/util/BFloat16.h>
#include <torch/headeronly/util/Half.h>
#include <torch/headeronly/util/shim_utils.h> // for STD_TORCH_CHECK
using TorchTensor = torch::stable::Tensor;
#define TORCH_UTILS_CHECK STD_TORCH_CHECK
#else
#include <torch/all.h>
using TorchTensor = torch::Tensor;
#define TORCH_UTILS_CHECK TORCH_CHECK
#endif
#include "cute/layout.hpp"
#include "cutlass/layout/matrix.h"
@@ -55,35 +70,35 @@ CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f) {
// If `tensor.dim() < rank(Stride{})`, the shape is padded with 1s and the extra
// strides are set to be 0 or 1.
template <typename Stride>
static inline auto make_cute_layout(torch::Tensor const& tensor,
static inline auto make_cute_layout(TorchTensor const& tensor,
std::string_view name = "tensor") {
TORCH_CHECK(tensor.dim() <= rank(Stride{}));
auto stride = cute::transform_with_idx(
Stride{}, [&](auto const& stride_ele, auto const& idx) {
using StrideEle = std::decay_t<decltype(stride_ele)>;
TORCH_UTILS_CHECK(tensor.dim() <= rank(Stride{}));
auto stride = cute::transform_with_idx(Stride{}, [&](auto const& stride_ele,
auto const& idx) {
using StrideEle = std::decay_t<decltype(stride_ele)>;
if (idx < tensor.dim()) {
if constexpr (cute::is_static_v<StrideEle>) {
TORCH_CHECK(StrideEle::value == tensor.stride(idx), "Expected ",
name, ".stride(", idx, ") to be ", StrideEle::value);
return StrideEle{};
} else {
if (tensor.size(idx) == 1) {
// use 0 stride for dim with size 1, this is easier for
// cute/cutlass to optimize (helps the TMA code flatten dims)
return StrideEle{0};
} else {
return tensor.stride(idx);
}
}
if (idx < tensor.dim()) {
if constexpr (cute::is_static_v<StrideEle>) {
TORCH_UTILS_CHECK(StrideEle::value == tensor.stride(idx), "Expected ",
name, ".stride(", idx, ") to be ", StrideEle::value);
return StrideEle{};
} else {
if (tensor.size(idx) == 1) {
// use 0 stride for dim with size 1, this is easier for
// cute/cutlass to optimize (helps the TMA code flatten dims)
return StrideEle{0};
} else {
// Extra strides are assumed to be 0 or 1
if constexpr (cute::is_static_v<StrideEle>) {
static_assert(StrideEle::value == 0 || StrideEle::value == 1);
}
return StrideEle{};
return tensor.stride(idx);
}
});
}
} else {
// Extra strides are assumed to be 0 or 1
if constexpr (cute::is_static_v<StrideEle>) {
static_assert(StrideEle::value == 0 || StrideEle::value == 1);
}
return StrideEle{};
}
});
auto shape = cute::make_shape_from_idx<rank(Stride{})>([&](auto const& idx) {
if (idx < tensor.dim())
@@ -97,7 +112,7 @@ static inline auto make_cute_layout(torch::Tensor const& tensor,
template <typename Stride>
static inline auto maybe_make_cute_layout(
std::optional<torch::Tensor> const& tensor,
std::optional<TorchTensor> const& tensor,
std::string_view name = "tensor") {
using Layout = decltype(make_cute_layout<Stride>(*tensor));
@@ -121,12 +136,12 @@ template <typename T>
using equivalent_cutlass_type_t = typename equivalent_cutlass_type<T>::type;
template <>
struct equivalent_cutlass_type<c10::Half> {
struct equivalent_cutlass_type<torch::headeronly::Half> {
using type = cutlass::half_t;
};
template <>
struct equivalent_cutlass_type<c10::BFloat16> {
struct equivalent_cutlass_type<torch::headeronly::BFloat16> {
using type = cutlass::bfloat16_t;
};
@@ -134,8 +149,8 @@ struct equivalent_cutlass_type<c10::BFloat16> {
// equivalent_scalar_t (basically inverse of equivalent_cutlass_type)
//
// Return a `c10::CppTypeToScalarType<T>` compatible type, i.e. get the C++ from
// c10 that is equivalent to T, e.g.: `cutlass::half_t -> c10::Half`
// Return a `torch::headeronly::CppTypeToScalarType<T>` compatible type, i.e.
// get the C++ type equivalent to T, e.g.: `cutlass::half_t -> Half`
template <typename T>
struct equivalent_scalar_type {
using type = T;
@@ -146,15 +161,15 @@ using equivalent_scalar_type_t = typename equivalent_scalar_type<T>::type;
template <>
struct equivalent_scalar_type<cutlass::half_t> {
using type = c10::Half;
using type = torch::headeronly::Half;
};
template <>
struct equivalent_scalar_type<cutlass::bfloat16_t> {
using type = c10::BFloat16;
using type = torch::headeronly::BFloat16;
};
// get equivalent c10::ScalarType tag from compile time type
// get equivalent torch::headeronly::ScalarType tag from compile time type
template <typename T>
static inline constexpr c10::ScalarType equivalent_scalar_type_v =
c10::CppTypeToScalarType<equivalent_scalar_type_t<T>>::value;
static inline constexpr torch::headeronly::ScalarType equivalent_scalar_type_v =
torch::headeronly::CppTypeToScalarType<equivalent_scalar_type_t<T>>::value;

View File

@@ -49,6 +49,15 @@
THO_DISPATCH_SWITCH(TYPE, NAME, \
VLLM_STABLE_DISPATCH_CASE_FP8_TYPES(__VA_ARGS__))
// Half types dispatch (Half + BFloat16)
#define VLLM_STABLE_DISPATCH_CASE_HALF_TYPES(...) \
THO_DISPATCH_CASE(torch::headeronly::ScalarType::Half, __VA_ARGS__) \
THO_DISPATCH_CASE(torch::headeronly::ScalarType::BFloat16, __VA_ARGS__)
#define VLLM_STABLE_DISPATCH_HALF_TYPES(TYPE, NAME, ...) \
THO_DISPATCH_SWITCH(TYPE, NAME, \
VLLM_STABLE_DISPATCH_CASE_HALF_TYPES(__VA_ARGS__))
// Boolean dispatch
#define VLLM_STABLE_DISPATCH_BOOL(expr, const_expr, ...) \
if (expr) { \

View File

@@ -84,4 +84,54 @@ void get_cutlass_batched_moe_mm_data(
const torch::stable::Tensor& expert_num_tokens,
const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
const int64_t k);
// FP4/NVFP4 ops
bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability);
void cutlass_scaled_fp4_mm(torch::stable::Tensor& D,
torch::stable::Tensor const& A,
torch::stable::Tensor const& B,
torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha);
void cutlass_fp4_group_mm(torch::stable::Tensor& output,
const torch::stable::Tensor& a,
const torch::stable::Tensor& b,
const torch::stable::Tensor& a_blockscale,
const torch::stable::Tensor& b_blockscales,
const torch::stable::Tensor& alphas,
const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& sf_offsets);
std::tuple<torch::stable::Tensor, torch::stable::Tensor> scaled_fp4_quant_func(
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_scale, bool is_sf_swizzled_layout);
void scaled_fp4_quant_out(torch::stable::Tensor const& input,
torch::stable::Tensor const& input_scale,
bool is_sf_swizzled_layout,
torch::stable::Tensor& output,
torch::stable::Tensor& output_scale);
void scaled_fp4_experts_quant(
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_global_scale,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts);
void silu_and_mul_scaled_fp4_experts_quant(
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_global_scale,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts);
void silu_and_mul_nvfp4_quant(torch::stable::Tensor& out,
torch::stable::Tensor& output_block_scale,
torch::stable::Tensor& input,
torch::stable::Tensor& input_global_scale);
#endif

View File

@@ -2,10 +2,9 @@
#pragma once
#include <cuda.h>
#include <torch/all.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "core/scalar_type.hpp"
#include "cutlass/bfloat16.h"
#include "cutlass/float8.h"
@@ -41,7 +40,7 @@ __global__ void get_group_gemm_starts(
}
#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE) \
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
else if (out_tensors.scalar_type() == TENSOR_C_TYPE) { \
get_group_gemm_starts<cutlass::float_e4m3_t, int32_t, C_TYPE, float, \
cutlass::Array<cutlass::float_e4m3_t, 8>> \
<<<1, num_experts, 0, stream>>>( \
@@ -66,23 +65,34 @@ __global__ void get_group_gemm_starts(
namespace {
void run_get_group_gemm_starts(
torch::Tensor const& expert_offsets, torch::Tensor& a_ptrs,
torch::Tensor& b_ptrs, torch::Tensor& out_ptrs,
torch::Tensor& a_scales_ptrs, torch::Tensor& b_scales_ptrs,
torch::Tensor& b_group_scales_ptrs, torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor& out_tensors,
torch::Tensor const& a_scales, torch::Tensor const& b_scales,
torch::Tensor const& b_group_scales, const int64_t b_group_size) {
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b_tensors.dtype() == torch::kInt32); // int4 8x packed into int32
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_group_scales.dtype() ==
torch::kFloat8_e4m3fn); // the underlying torch type is e4m3
TORCH_CHECK(out_tensors.dtype() ==
torch::kBFloat16); // only support bf16 for now
torch::stable::Tensor const& expert_offsets, torch::stable::Tensor& a_ptrs,
torch::stable::Tensor& b_ptrs, torch::stable::Tensor& out_ptrs,
torch::stable::Tensor& a_scales_ptrs, torch::stable::Tensor& b_scales_ptrs,
torch::stable::Tensor& b_group_scales_ptrs,
torch::stable::Tensor const& a_tensors,
torch::stable::Tensor const& b_tensors, torch::stable::Tensor& out_tensors,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& b_group_scales, const int64_t b_group_size) {
STD_TORCH_CHECK(a_tensors.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
STD_TORCH_CHECK(
b_tensors.scalar_type() ==
torch::headeronly::ScalarType::Int); // int4 8x packed into int32
STD_TORCH_CHECK(a_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(
b_group_scales.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn); // the underlying torch
// type is e4m3
STD_TORCH_CHECK(
out_tensors.scalar_type() ==
torch::headeronly::ScalarType::BFloat16); // only support bf16 for now
// expect int64_t to avoid overflow during offset calculations
TORCH_CHECK(expert_offsets.dtype() == torch::kInt64);
STD_TORCH_CHECK(expert_offsets.scalar_type() ==
torch::headeronly::ScalarType::Long);
int num_experts = static_cast<int>(expert_offsets.size(0));
// logical k, n
@@ -90,15 +100,16 @@ void run_get_group_gemm_starts(
int64_t k = a_tensors.size(1);
int64_t scale_k = cutlass::ceil_div(k, b_group_size);
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
auto stream = get_current_cuda_stream(a_tensors.get_device_index());
if (false) {
}
__CALL_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t)
__CALL_GET_STARTS_KERNEL(torch::kFloat16, half)
__CALL_GET_STARTS_KERNEL(torch::headeronly::ScalarType::BFloat16,
cutlass::bfloat16_t)
__CALL_GET_STARTS_KERNEL(torch::headeronly::ScalarType::Half, half)
else {
TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
STD_TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
}
}
} // namespace
} // namespace

View File

@@ -14,13 +14,12 @@
#include "cutlass/util/mixed_dtype_utils.hpp"
// vllm includes
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "cutlass_extensions/torch_utils.hpp"
#include "cutlass_extensions/common.hpp"
#include "core/registration.h"
#include "get_group_starts.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
#include "w4a8_utils.cuh"
@@ -168,31 +167,40 @@ struct W4A8GroupedGemmKernel {
static_assert(sizeof(LayoutB_Reordered) % sizeof(int32_t) == 0,
"LayoutB_Reordered size must be divisible by 4 bytes");
static void grouped_mm(
torch::Tensor& out_tensors, const torch::Tensor& a_tensors,
const torch::Tensor& b_tensors, const torch::Tensor& a_scales,
const torch::Tensor& b_scales, const torch::Tensor& b_group_scales,
const int64_t b_group_size, const torch::Tensor& expert_offsets,
const torch::Tensor& problem_sizes_torch, const torch::Tensor& a_strides,
const torch::Tensor& b_strides, const torch::Tensor& c_strides,
const torch::Tensor& group_scale_strides) {
static void grouped_mm(torch::stable::Tensor& out_tensors,
const torch::stable::Tensor& a_tensors,
const torch::stable::Tensor& b_tensors,
const torch::stable::Tensor& a_scales,
const torch::stable::Tensor& b_scales,
const torch::stable::Tensor& b_group_scales,
const int64_t b_group_size,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& problem_sizes_torch,
const torch::stable::Tensor& a_strides,
const torch::stable::Tensor& b_strides,
const torch::stable::Tensor& c_strides,
const torch::stable::Tensor& group_scale_strides) {
auto device = a_tensors.device();
auto device_id = device.index();
const at::cuda::OptionalCUDAGuard device_guard(device);
auto stream = at::cuda::getCurrentCUDAStream(device_id);
const torch::stable::accelerator::DeviceGuard device_guard(device_id);
auto stream = get_current_cuda_stream(device_id);
int num_experts = static_cast<int>(expert_offsets.size(0));
int n = static_cast<int>(b_tensors.size(1));
int k = static_cast<int>(b_tensors.size(2)) * PackFactor;
auto options_int =
torch::TensorOptions().dtype(torch::kInt64).device(device);
torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_group_scales_ptrs = torch::empty(num_experts, options_int);
torch::stable::Tensor a_ptrs = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor b_ptrs = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor out_ptrs = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor a_scales_ptrs = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor b_scales_ptrs = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor b_group_scales_ptrs = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
// get the correct offsets to pass to gemm
run_get_group_gemm_starts(expert_offsets, a_ptrs, b_ptrs, out_ptrs,
@@ -247,9 +255,9 @@ struct W4A8GroupedGemmKernel {
// Allocate workspace
size_t workspace_size = GemmShuffled::get_workspace_size(arguments);
torch::Tensor workspace =
torch::empty(workspace_size,
torch::TensorOptions().dtype(torch::kU8).device(device));
torch::stable::Tensor workspace = torch::stable::empty(
workspace_size, torch::headeronly::ScalarType::Byte, std::nullopt,
device);
// Run GEMM
GemmShuffled gemm;
@@ -294,14 +302,20 @@ using Kernel_256x128_2x1x1_Coop =
using Kernel_128x256_2x1x1_Coop =
W4A8GroupedGemmKernel<Shape<_128, _256>, Shape<_2, _1, _1>, Coop, CoopEpi>;
void mm_dispatch(
torch::Tensor& out_tensors, const torch::Tensor& a_tensors,
const torch::Tensor& b_tensors, const torch::Tensor& a_scales,
const torch::Tensor& b_scales, const torch::Tensor& b_group_scales,
const int64_t b_group_size, const torch::Tensor& expert_offsets,
const torch::Tensor& problem_sizes, const torch::Tensor& a_strides,
const torch::Tensor& b_strides, const torch::Tensor& c_strides,
const torch::Tensor& group_scale_strides, const std::string& schedule) {
void mm_dispatch(torch::stable::Tensor& out_tensors,
const torch::stable::Tensor& a_tensors,
const torch::stable::Tensor& b_tensors,
const torch::stable::Tensor& a_scales,
const torch::stable::Tensor& b_scales,
const torch::stable::Tensor& b_group_scales,
const int64_t b_group_size,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& a_strides,
const torch::stable::Tensor& b_strides,
const torch::stable::Tensor& c_strides,
const torch::stable::Tensor& group_scale_strides,
const std::string& schedule) {
if (schedule == "Kernel_128x16_1x1x1_Coop") {
Kernel_128x16_1x1x1_Coop::grouped_mm(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales,
@@ -358,18 +372,23 @@ void mm_dispatch(
b_group_size, expert_offsets, problem_sizes, a_strides, b_strides,
c_strides, group_scale_strides);
} else {
TORCH_CHECK(false,
"cutlass_w4a8_moe_mm: unknown schedule string: ", schedule);
STD_TORCH_CHECK(false,
"cutlass_w4a8_moe_mm: unknown schedule string: ", schedule);
}
}
void mm(torch::Tensor& out_tensors, const torch::Tensor& a_tensors,
const torch::Tensor& b_tensors, const torch::Tensor& a_scales,
const torch::Tensor& b_scales, const torch::Tensor& b_group_scales,
const int64_t b_group_size, const torch::Tensor& expert_offsets,
const torch::Tensor& problem_sizes, const torch::Tensor& a_strides,
const torch::Tensor& b_strides, const torch::Tensor& c_strides,
const torch::Tensor& group_scale_strides,
void mm(torch::stable::Tensor& out_tensors,
const torch::stable::Tensor& a_tensors,
const torch::stable::Tensor& b_tensors,
const torch::stable::Tensor& a_scales,
const torch::stable::Tensor& b_scales,
const torch::stable::Tensor& b_group_scales, const int64_t b_group_size,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& a_strides,
const torch::stable::Tensor& b_strides,
const torch::stable::Tensor& c_strides,
const torch::stable::Tensor& group_scale_strides,
std::optional<std::string> maybe_schedule) {
// user has specified a schedule
if (maybe_schedule) {
@@ -406,26 +425,27 @@ void mm(torch::Tensor& out_tensors, const torch::Tensor& a_tensors,
a_strides, b_strides, c_strides, group_scale_strides, schedule);
}
std::tuple<torch::Tensor, torch::Tensor> encode_and_reorder_int4b(
torch::Tensor const& b_tensors) {
TORCH_CHECK(b_tensors.dtype() == torch::kInt32);
TORCH_CHECK(b_tensors.dim() == 3); // (experts, n, k)
TORCH_CHECK(b_tensors.is_contiguous());
TORCH_CHECK(b_tensors.is_cuda());
std::tuple<torch::stable::Tensor, torch::stable::Tensor>
encode_and_reorder_int4b(torch::stable::Tensor const& b_tensors) {
STD_TORCH_CHECK(b_tensors.scalar_type() ==
torch::headeronly::ScalarType::Int);
STD_TORCH_CHECK(b_tensors.dim() == 3); // (experts, n, k)
STD_TORCH_CHECK(b_tensors.is_contiguous());
STD_TORCH_CHECK(b_tensors.is_cuda());
int n = static_cast<int>(b_tensors.size(1));
int k = static_cast<int>(b_tensors.size(2)) * PackFactor; // logical k
// CUTLASS reorder_tensor requires k % 256 == 0 and n % 16 == 0.
// These misalignments cause silent OOB unless run under Compute Sanitizer.
TORCH_CHECK(k % 256 == 0, "logical k must be divisible by 256");
TORCH_CHECK(n % 16 == 0, "n must be divisible by 16");
STD_TORCH_CHECK(k % 256 == 0, "logical k must be divisible by 256");
STD_TORCH_CHECK(n % 16 == 0, "n must be divisible by 16");
// we will store the layout to an int32 tensor;
// this is the number of elements we need per layout
constexpr size_t layout_width = sizeof(LayoutB_Reordered) / sizeof(int32_t);
torch::Tensor b_tensors_packed = torch::empty_like(b_tensors);
torch::stable::Tensor b_tensors_packed = torch::stable::empty_like(b_tensors);
int num_experts = static_cast<int>(b_tensors.size(0));
auto b_ptr = static_cast<QuantType const*>(b_tensors.const_data_ptr());
@@ -435,7 +455,7 @@ std::tuple<torch::Tensor, torch::Tensor> encode_and_reorder_int4b(
size_t num_int4_elems = 1ull * num_experts * n * k;
bool ok = vllm::cutlass_w4a8_utils::unified_encode_int4b(b_ptr, b_packed_ptr,
num_int4_elems);
TORCH_CHECK(ok, "unified_encode_int4b failed");
STD_TORCH_CHECK(ok, "unified_encode_int4b failed");
// construct the layout once; assumes each expert has the same layout
using LayoutType = LayoutB_Reordered;
@@ -456,28 +476,28 @@ std::tuple<torch::Tensor, torch::Tensor> encode_and_reorder_int4b(
}
// save the packed layout to torch tensor so we can re-use it
auto cpu_opts =
torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU);
torch::Tensor layout_cpu =
torch::empty({num_experts, layout_width}, cpu_opts);
torch::stable::Tensor layout_cpu = torch::stable::empty(
{num_experts, layout_width}, torch::headeronly::ScalarType::Int,
std::nullopt, torch::stable::Device(torch::stable::DeviceType::CPU));
int32_t* layout_data = layout_cpu.data_ptr<int32_t>();
int32_t* layout_data = layout_cpu.mutable_data_ptr<int32_t>();
for (int i = 0; i < num_experts; ++i) {
std::memcpy(layout_data + i * layout_width, // dst (int32*)
&layout_B_reordered, // src (LayoutType*)
sizeof(LayoutType)); // number of bytes
}
torch::Tensor packed_layout =
layout_cpu.to(b_tensors.device(), /*non_blocking=*/false);
torch::stable::Tensor packed_layout =
torch::stable::to(layout_cpu, b_tensors.device(),
/*non_blocking=*/false);
return {b_tensors_packed, packed_layout};
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("cutlass_w4a8_moe_mm", &mm);
m.impl("cutlass_encode_and_reorder_int4b_grouped", &encode_and_reorder_int4b);
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) {
m.impl("cutlass_w4a8_moe_mm", TORCH_BOX(&mm));
m.impl("cutlass_encode_and_reorder_int4b_grouped",
TORCH_BOX(&encode_and_reorder_int4b));
}
} // namespace vllm::cutlass_w4a8_moe
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -3,14 +3,12 @@
// https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu
//
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "cutlass_extensions/torch_utils.hpp"
#include "w4a8_utils.cuh"
#include "core/registration.h"
#include "cutlass/cutlass.h"
#include <limits>
@@ -161,31 +159,31 @@ struct W4A8GemmKernel {
using StrideD = typename GemmKernelShuffled::StrideD;
using StrideS = typename CollectiveMainloopShuffled::StrideScale;
static torch::Tensor mm(torch::Tensor const& A,
torch::Tensor const& B, // already packed
torch::Tensor const& group_scales, // already packed
int64_t group_size,
torch::Tensor const& channel_scales,
torch::Tensor const& token_scales,
std::optional<at::ScalarType> const& maybe_out_type) {
static torch::stable::Tensor mm(
torch::stable::Tensor const& A,
torch::stable::Tensor const& B, // already packed
torch::stable::Tensor const& group_scales, // already packed
int64_t group_size, torch::stable::Tensor const& channel_scales,
torch::stable::Tensor const& token_scales,
std::optional<torch::headeronly::ScalarType> const& maybe_out_type) {
// TODO: param validation
int m = A.size(0);
int k = A.size(1);
int n = B.size(1);
// safely cast group_size to int
TORCH_CHECK(group_size > 0 && group_size <= std::numeric_limits<int>::max(),
"group_size out of supported range for int: ", group_size);
STD_TORCH_CHECK(
group_size > 0 && group_size <= std::numeric_limits<int>::max(),
"group_size out of supported range for int: ", group_size);
int const group_size_int = static_cast<int>(group_size);
// Allocate output
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
const torch::stable::accelerator::DeviceGuard device_guard(
A.get_device_index());
auto device = A.device();
auto stream = at::cuda::getCurrentCUDAStream(device.index());
torch::Tensor D =
torch::empty({m, n}, torch::TensorOptions()
.dtype(equivalent_scalar_type_v<ElementD>)
.device(device));
auto stream = get_current_cuda_stream(device.index());
torch::stable::Tensor D = torch::stable::empty(
{m, n}, equivalent_scalar_type_v<ElementD>, std::nullopt, device);
// prepare arg pointers
auto A_ptr = static_cast<MmaType const*>(A.const_data_ptr());
auto B_ptr = static_cast<QuantType const*>(B.const_data_ptr());
@@ -237,9 +235,9 @@ struct W4A8GemmKernel {
// Workspace
size_t workspace_size = GemmShuffled::get_workspace_size(arguments);
torch::Tensor workspace =
torch::empty(workspace_size,
torch::TensorOptions().dtype(torch::kU8).device(device));
torch::stable::Tensor workspace = torch::stable::empty(
workspace_size, torch::headeronly::ScalarType::Byte, std::nullopt,
device);
// Run GEMM
GemmShuffled gemm;
@@ -269,14 +267,14 @@ using Kernel_128x64_1x1x1 = W4A8GemmKernel<Shape<_128, _64>, Shape<_1, _1, _1>>;
using Kernel_128x32_1x1x1 = W4A8GemmKernel<Shape<_128, _32>, Shape<_1, _1, _1>>;
using Kernel_128x16_1x1x1 = W4A8GemmKernel<Shape<_128, _16>, Shape<_1, _1, _1>>;
torch::Tensor mm_dispatch(torch::Tensor const& A,
torch::Tensor const& B, // already packed
torch::Tensor const& group_scales, // already packed
int64_t group_size,
torch::Tensor const& channel_scales,
torch::Tensor const& token_scales,
std::optional<at::ScalarType> const& maybe_out_type,
const std::string& schedule) {
torch::stable::Tensor mm_dispatch(
torch::stable::Tensor const& A,
torch::stable::Tensor const& B, // already packed
torch::stable::Tensor const& group_scales, // already packed
int64_t group_size, torch::stable::Tensor const& channel_scales,
torch::stable::Tensor const& token_scales,
std::optional<torch::headeronly::ScalarType> const& maybe_out_type,
const std::string& schedule) {
if (schedule == "256x128_1x1x1") {
return Kernel_256x128_1x1x1::mm(A, B, group_scales, group_size,
channel_scales, token_scales,
@@ -318,17 +316,18 @@ torch::Tensor mm_dispatch(torch::Tensor const& A,
channel_scales, token_scales,
maybe_out_type);
}
TORCH_CHECK(false, "Unknown W4A8 schedule: ", schedule);
STD_TORCH_CHECK(false, "Unknown W4A8 schedule: ", schedule);
return {};
}
torch::Tensor mm(torch::Tensor const& A,
torch::Tensor const& B, // already packed
torch::Tensor const& group_scales, // already packed
int64_t group_size, torch::Tensor const& channel_scales,
torch::Tensor const& token_scales,
std::optional<at::ScalarType> const& maybe_out_type,
std::optional<std::string> maybe_schedule) {
torch::stable::Tensor mm(
torch::stable::Tensor const& A,
torch::stable::Tensor const& B, // already packed
torch::stable::Tensor const& group_scales, // already packed
int64_t group_size, torch::stable::Tensor const& channel_scales,
torch::stable::Tensor const& token_scales,
std::optional<torch::headeronly::ScalarType> const& maybe_out_type,
std::optional<std::string> maybe_schedule) {
// requested a specific schedule
if (maybe_schedule) {
return mm_dispatch(A, B, group_scales, group_size, channel_scales,
@@ -378,14 +377,15 @@ torch::Tensor mm(torch::Tensor const& A,
// ----------------------------------------------------------------------------
// Pre-processing utils
// ----------------------------------------------------------------------------
torch::Tensor pack_scale_fp8(torch::Tensor const& scales) {
TORCH_CHECK(scales.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(scales.is_contiguous());
TORCH_CHECK(scales.is_cuda());
torch::stable::Tensor pack_scale_fp8(torch::stable::Tensor const& scales) {
STD_TORCH_CHECK(scales.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
STD_TORCH_CHECK(scales.is_contiguous());
STD_TORCH_CHECK(scales.is_cuda());
auto packed_scales = torch::empty(
{scales.numel() * ScalePackSize},
torch::TensorOptions().dtype(scales.dtype()).device(scales.device()));
auto packed_scales =
torch::stable::empty({scales.numel() * ScalePackSize},
scales.scalar_type(), std::nullopt, scales.device());
auto scales_ptr = static_cast<MmaType const*>(scales.const_data_ptr());
auto packed_scales_ptr =
static_cast<cutlass::Array<ElementScale, ScalePackSize>*>(
@@ -396,15 +396,16 @@ torch::Tensor pack_scale_fp8(torch::Tensor const& scales) {
return packed_scales;
}
torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) {
TORCH_CHECK(B.dtype() == torch::kInt32);
TORCH_CHECK(B.dim() == 2);
torch::stable::Tensor encode_and_reorder_int4b(torch::stable::Tensor const& B) {
STD_TORCH_CHECK(B.scalar_type() == torch::headeronly::ScalarType::Int);
STD_TORCH_CHECK(B.dim() == 2);
torch::Tensor B_packed = torch::empty_like(B);
torch::stable::Tensor B_packed = torch::stable::empty_like(B);
int k = B.size(0) * PackFactor; // logical k
int n = B.size(1);
TORCH_CHECK((n * k) % 32 == 0, "need multiples of 32 int4s for 16B chunks");
STD_TORCH_CHECK((n * k) % 32 == 0,
"need multiples of 32 int4s for 16B chunks");
auto B_ptr = static_cast<QuantType const*>(B.const_data_ptr());
auto B_packed_ptr = static_cast<QuantType*>(B_packed.data_ptr());
@@ -415,16 +416,17 @@ torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) {
bool ok = vllm::cutlass_w4a8_utils::unified_encode_int4b(B_ptr, B_packed_ptr,
n * k);
TORCH_CHECK(ok, "unified_encode_int4b failed");
STD_TORCH_CHECK(ok, "unified_encode_int4b failed");
cutlass::reorder_tensor(B_packed_ptr, layout_B, layout_B_reordered);
return B_packed;
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("cutlass_w4a8_mm", &mm);
m.impl("cutlass_pack_scale_fp8", &pack_scale_fp8);
m.impl("cutlass_encode_and_reorder_int4b", &encode_and_reorder_int4b);
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) {
m.impl("cutlass_w4a8_mm", TORCH_BOX(&mm));
m.impl("cutlass_pack_scale_fp8", TORCH_BOX(&pack_scale_fp8));
m.impl("cutlass_encode_and_reorder_int4b",
TORCH_BOX(&encode_and_reorder_int4b));
}
} // namespace vllm::cutlass_w4a8
} // namespace vllm::cutlass_w4a8

View File

@@ -14,16 +14,15 @@
* limitations under the License.
*/
#include <torch/all.h>
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "libtorch_stable/dispatch_utils.h"
#include "cuda_vec_utils.cuh"
#include <cuda_runtime_api.h>
#include <cuda_runtime.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_fp8.h>
#include "dispatch_utils.h"
#include "cuda_utils.h"
#include "launch_bounds_utils.h"
@@ -118,17 +117,19 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
} // namespace vllm
void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d]
torch::Tensor& output_sf,
torch::Tensor& input, // [..., 2 * d]
torch::Tensor& input_sf) {
void silu_and_mul_nvfp4_quant_sm1xxa(
torch::stable::Tensor& output, // [..., d]
torch::stable::Tensor& output_sf,
torch::stable::Tensor& input, // [..., 2 * d]
torch::stable::Tensor& input_sf) {
int32_t m = input.size(0);
int32_t n = input.size(1) / 2;
TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");
TORCH_CHECK(input.scalar_type() == at::ScalarType::Half ||
input.scalar_type() == at::ScalarType::BFloat16,
"Unsupported input data type for quantize_to_fp4.");
STD_TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");
STD_TORCH_CHECK(
input.scalar_type() == torch::headeronly::ScalarType::Half ||
input.scalar_type() == torch::headeronly::ScalarType::BFloat16,
"Unsupported input data type for quantize_to_fp4.");
int multiProcessorCount =
get_device_attribute(cudaDevAttrMultiProcessorCount, -1);
@@ -136,8 +137,9 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d]
auto input_sf_ptr = static_cast<float const*>(input_sf.data_ptr());
auto sf_out = static_cast<int32_t*>(output_sf.data_ptr());
auto output_ptr = static_cast<int64_t*>(output.data_ptr());
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
const torch::stable::accelerator::DeviceGuard device_guard(
input.get_device_index());
auto stream = get_current_cuda_stream(input.get_device_index());
dim3 block(std::min(int(n / ELTS_PER_THREAD), 512));
int const numBlocksPerSM =
vllm_runtime_blocks_per_sm(static_cast<int>(block.x));
@@ -149,7 +151,7 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d]
int(m), std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y));
dim3 grid(grid_x, grid_y);
VLLM_DISPATCH_HALF_TYPES(
VLLM_STABLE_DISPATCH_HALF_TYPES(
input.scalar_type(), "silu_and_mul_nvfp4_quant_kernel", [&] {
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());

View File

@@ -14,14 +14,12 @@
* limitations under the License.
*/
#include "core/registration.h"
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include <torch/all.h>
#include <cutlass/arch/arch.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include "cutlass_extensions/common.hpp"
#include "cute/tensor.hpp"
@@ -122,7 +120,7 @@ __global__ void __get_group_gemm_starts(
#define __CALL_GET_STARTS_KERNEL_BLOCKSCALE(ELEMENT_AB_TYPE, SF_TYPE, \
TENSOR_C_TYPE, C_TYPE, LayoutSFA, \
LayoutSFB, ScaleConfig) \
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
else if (out_tensors.scalar_type() == TENSOR_C_TYPE) { \
__get_group_gemm_starts<ELEMENT_AB_TYPE, C_TYPE, SF_TYPE, float, \
LayoutSFA, LayoutSFB, ScaleConfig> \
<<<1, num_experts, 0, stream>>>( \
@@ -150,50 +148,64 @@ __global__ void __get_group_gemm_starts(
}
template <typename LayoutSFA, typename LayoutSFB, typename ScaleConfig>
void run_get_group_gemm_starts(
const torch::Tensor& a_starts, const torch::Tensor& b_starts,
const torch::Tensor& out_starts, const torch::Tensor& a_scales_starts,
const torch::Tensor& b_scales_starts, const torch::Tensor& alpha_starts,
const torch::Tensor& layout_sfa, const torch::Tensor& layout_sfb,
const torch::Tensor& a_strides, const torch::Tensor& b_strides,
const torch::Tensor& c_strides, int64_t a_stride_val, int64_t b_stride_val,
int64_t c_stride_val,
/*these are used for their base addresses*/
torch::Tensor const& a_tensors, torch::Tensor const& b_tensors,
torch::Tensor const& out_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& alphas,
torch::Tensor const& expert_offsets, torch::Tensor const& sf_offsets,
torch::Tensor const& problem_sizes, int M, int N, int K) {
void run_get_group_gemm_starts(const torch::stable::Tensor& a_starts,
const torch::stable::Tensor& b_starts,
const torch::stable::Tensor& out_starts,
const torch::stable::Tensor& a_scales_starts,
const torch::stable::Tensor& b_scales_starts,
const torch::stable::Tensor& alpha_starts,
const torch::stable::Tensor& layout_sfa,
const torch::stable::Tensor& layout_sfb,
const torch::stable::Tensor& a_strides,
const torch::stable::Tensor& b_strides,
const torch::stable::Tensor& c_strides,
int64_t a_stride_val, int64_t b_stride_val,
int64_t c_stride_val,
/*these are used for their base addresses*/
torch::stable::Tensor const& a_tensors,
torch::stable::Tensor const& b_tensors,
torch::stable::Tensor const& out_tensors,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& alphas,
torch::stable::Tensor const& expert_offsets,
torch::stable::Tensor const& sf_offsets,
torch::stable::Tensor const& problem_sizes,
int M, int N, int K) {
int num_experts = (int)expert_offsets.size(0);
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
auto stream = get_current_cuda_stream(a_tensors.get_device_index());
TORCH_CHECK(out_tensors.size(1) == N,
"Output tensor shape doesn't match expected shape");
TORCH_CHECK(K / 2 == b_tensors.size(2),
"b_tensors(dim = 2) and a_tensors(dim = 1) trailing"
" dimension must match");
STD_TORCH_CHECK(out_tensors.size(1) == N,
"Output tensor shape doesn't match expected shape");
STD_TORCH_CHECK(K / 2 == b_tensors.size(2),
"b_tensors(dim = 2) and a_tensors(dim = 1) trailing"
" dimension must match");
if (false) {
}
//(ELEMENT_AB_TYPE, BS_TYPE, TENSOR_C_TYPE, C_TYPE, LayoutSFA, LayoutSFB,
// ScaleConfig)
__CALL_GET_STARTS_KERNEL_BLOCKSCALE(
cutlass::float_e2m1_t, cutlass::float_ue4m3_t, torch::kBFloat16,
cutlass::bfloat16_t, LayoutSFA, LayoutSFB, ScaleConfig)
cutlass::float_e2m1_t, cutlass::float_ue4m3_t,
torch::headeronly::ScalarType::BFloat16, cutlass::bfloat16_t, LayoutSFA,
LayoutSFB, ScaleConfig)
__CALL_GET_STARTS_KERNEL_BLOCKSCALE(cutlass::float_e2m1_t,
cutlass::float_ue4m3_t, torch::kFloat16,
half, LayoutSFA, LayoutSFB, ScaleConfig)
cutlass::float_ue4m3_t,
torch::headeronly::ScalarType::Half, half,
LayoutSFA, LayoutSFB, ScaleConfig)
else {
TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
STD_TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
}
}
template <typename OutType>
void run_fp4_blockwise_scaled_group_mm_sm100(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M,
int N, int K) {
torch::stable::Tensor& output, const torch::stable::Tensor& a,
const torch::stable::Tensor& b, const torch::stable::Tensor& a_blockscale,
const torch::stable::Tensor& b_blockscales,
const torch::stable::Tensor& alphas,
const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& sf_offsets, int M, int N, int K) {
using ProblemShape =
cutlass::gemm::GroupProblemShape<Shape<int32_t, int32_t, int32_t>>;
using ElementType = cutlass::float_e2m1_t;
@@ -272,20 +284,40 @@ void run_fp4_blockwise_scaled_group_mm_sm100(
using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
int num_experts = static_cast<int>(expert_offsets.size(0));
auto options_int =
torch::TensorOptions().dtype(torch::kInt64).device(a.device());
torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int);
torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int);
torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int);
torch::Tensor a_strides1 = torch::empty(num_experts, options_int);
torch::Tensor b_strides1 = torch::empty(num_experts, options_int);
torch::Tensor c_strides1 = torch::empty(num_experts, options_int);
torch::stable::Tensor a_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor b_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor out_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor a_scales_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor b_scales_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor alpha_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor layout_sfa = torch::stable::empty(
{num_experts, 5}, torch::headeronly::ScalarType::Long, std::nullopt,
a.device());
torch::stable::Tensor layout_sfb = torch::stable::empty(
{num_experts, 5}, torch::headeronly::ScalarType::Long, std::nullopt,
a.device());
torch::stable::Tensor a_strides1 =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor b_strides1 =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor c_strides1 =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
run_get_group_gemm_starts<LayoutSFA, LayoutSFB, ScaleConfig>(
a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, alpha_ptrs,
@@ -308,7 +340,7 @@ void run_fp4_blockwise_scaled_group_mm_sm100(
typename ProblemShape::UnderlyingProblemShape>::RasterOrderOptions;
typename Gemm::GemmKernel::TileSchedulerArguments scheduler;
scheduler.raster_order = RasterOrderOptions::AlongM;
hw_info.device_id = a.get_device();
hw_info.device_id = a.get_device_index();
static std::unordered_map<int, int> cached_sm_counts;
if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) {
cached_sm_counts[hw_info.device_id] =
@@ -350,32 +382,35 @@ void run_fp4_blockwise_scaled_group_mm_sm100(
scheduler};
size_t workspace_size = Gemm::get_workspace_size(args);
auto const workspace_options =
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
auto workspace = torch::empty(workspace_size, workspace_options);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device());
auto workspace =
torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
std::nullopt, a.device());
const cudaStream_t stream = get_current_cuda_stream(a.get_device_index());
auto can_implement_status = gemm_op.can_implement(args);
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess,
"Failed to implement GEMM: status=", (int)can_implement_status);
STD_TORCH_CHECK(
can_implement_status == cutlass::Status::kSuccess,
"Failed to implement GEMM: status=", (int)can_implement_status);
// Run the GEMM
auto status = gemm_op.initialize(args, workspace.data_ptr());
TORCH_CHECK(status == cutlass::Status::kSuccess,
"Failed to initialize GEMM: status=", (int)status,
" workspace_size=", workspace_size, " num_experts=", num_experts,
" M=", M, " N=", N, " K=", K);
STD_TORCH_CHECK(status == cutlass::Status::kSuccess,
"Failed to initialize GEMM: status=", (int)status,
" workspace_size=", workspace_size,
" num_experts=", num_experts, " M=", M, " N=", N, " K=", K);
status = gemm_op.run(args, workspace.data_ptr(), stream);
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
STD_TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
}
void run_fp4_blockwise_scaled_group_mm_sm120(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M,
int N, int K) {
torch::stable::Tensor& output, const torch::stable::Tensor& a,
const torch::stable::Tensor& b, const torch::stable::Tensor& a_blockscale,
const torch::stable::Tensor& b_blockscales,
const torch::stable::Tensor& alphas,
const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& sf_offsets, int M, int N, int K) {
using ProblemShape =
cutlass::gemm::GroupProblemShape<Shape<int32_t, int32_t, int32_t>>;
using ElementType = cutlass::float_e2m1_t;
@@ -446,20 +481,40 @@ void run_fp4_blockwise_scaled_group_mm_sm120(
using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
int num_experts = static_cast<int>(expert_offsets.size(0));
auto options_int =
torch::TensorOptions().dtype(torch::kInt64).device(a.device());
torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int);
torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int);
torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int);
torch::Tensor a_strides1 = torch::empty(num_experts, options_int);
torch::Tensor b_strides1 = torch::empty(num_experts, options_int);
torch::Tensor c_strides1 = torch::empty(num_experts, options_int);
torch::stable::Tensor a_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor b_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor out_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor a_scales_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor b_scales_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor alpha_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor layout_sfa = torch::stable::empty(
{num_experts, 5}, torch::headeronly::ScalarType::Long, std::nullopt,
a.device());
torch::stable::Tensor layout_sfb = torch::stable::empty(
{num_experts, 5}, torch::headeronly::ScalarType::Long, std::nullopt,
a.device());
torch::stable::Tensor a_strides1 =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor b_strides1 =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor c_strides1 =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
run_get_group_gemm_starts<LayoutSFA, LayoutSFB, ScaleConfig>(
a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, alpha_ptrs,
@@ -480,7 +535,7 @@ void run_fp4_blockwise_scaled_group_mm_sm120(
using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions;
typename Gemm::GemmKernel::TileSchedulerArguments scheduler;
scheduler.raster_order = RasterOrderOptions::AlongM;
hw_info.device_id = a.get_device();
hw_info.device_id = a.get_device_index();
static std::unordered_map<int, int> cached_sm_counts;
if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) {
cached_sm_counts[hw_info.device_id] =
@@ -523,33 +578,36 @@ void run_fp4_blockwise_scaled_group_mm_sm120(
scheduler};
size_t workspace_size = Gemm::get_workspace_size(args);
auto const workspace_options =
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
auto workspace = torch::empty(workspace_size, workspace_options);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device());
auto workspace =
torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
std::nullopt, a.device());
const cudaStream_t stream = get_current_cuda_stream(a.get_device_index());
auto can_implement_status = gemm_op.can_implement(args);
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess,
"Failed to implement GEMM: status=", (int)can_implement_status);
STD_TORCH_CHECK(
can_implement_status == cutlass::Status::kSuccess,
"Failed to implement GEMM: status=", (int)can_implement_status);
// Run the GEMM
auto status = gemm_op.initialize(args, workspace.data_ptr());
TORCH_CHECK(status == cutlass::Status::kSuccess,
"Failed to initialize GEMM: status=", (int)status,
" workspace_size=", workspace_size, " num_experts=", num_experts,
" M=", M, " N=", N, " K=", K);
STD_TORCH_CHECK(status == cutlass::Status::kSuccess,
"Failed to initialize GEMM: status=", (int)status,
" workspace_size=", workspace_size,
" num_experts=", num_experts, " M=", M, " N=", N, " K=", K);
status = gemm_op.run(args, workspace.data_ptr(), stream);
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
STD_TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
}
template <typename OutType>
void run_fp4_blockwise_scaled_group_mm(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M,
int N, int K) {
torch::stable::Tensor& output, const torch::stable::Tensor& a,
const torch::stable::Tensor& b, const torch::stable::Tensor& a_blockscale,
const torch::stable::Tensor& b_blockscales,
const torch::stable::Tensor& alphas,
const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& sf_offsets, int M, int N, int K) {
int32_t version_num = get_sm_version_num();
#if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
if (version_num >= 120 && version_num < 130) {
@@ -567,7 +625,7 @@ void run_fp4_blockwise_scaled_group_mm(
return;
}
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_fp4_group_mm kernel for CUDA device capability: ",
version_num, ". Required capability: 100 or 120");
@@ -575,26 +633,31 @@ void run_fp4_blockwise_scaled_group_mm(
#if (defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100) || \
(defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120)
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
constexpr auto FLOAT4_E2M1X2 = torch::headeronly::ScalarType::Byte;
constexpr auto SF_DTYPE = torch::headeronly::ScalarType::Float8_e4m3fn;
#endif
#define CHECK_TYPE(x, st, m) \
TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m)
#define CHECK_TYPE(x, st, m) \
STD_TORCH_CHECK(x.scalar_type() == st, \
": Inconsistency of torch::stable::Tensor type:", m)
#define CHECK_TH_CUDA(x, m) \
TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor.")
STD_TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor.")
#define CHECK_CONTIGUOUS(x, m) \
TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous.")
STD_TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous.")
#define CHECK_INPUT(x, st, m) \
CHECK_TH_CUDA(x, m); \
CHECK_CONTIGUOUS(x, m); \
CHECK_TYPE(x, st, m)
void cutlass_fp4_group_mm(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets) {
void cutlass_fp4_group_mm(torch::stable::Tensor& output,
const torch::stable::Tensor& a,
const torch::stable::Tensor& b,
const torch::stable::Tensor& a_blockscale,
const torch::stable::Tensor& b_blockscales,
const torch::stable::Tensor& alphas,
const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& sf_offsets) {
#if (defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100) || \
(defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120)
// Input validation
@@ -602,30 +665,34 @@ void cutlass_fp4_group_mm(
CHECK_INPUT(b, FLOAT4_E2M1X2, "b");
CHECK_INPUT(a_blockscale, SF_DTYPE, "a_blockscale");
CHECK_INPUT(b_blockscales, SF_DTYPE, "b_blockscales");
CHECK_INPUT(alphas, at::ScalarType::Float, "alphas");
CHECK_INPUT(alphas, torch::headeronly::ScalarType::Float, "alphas");
TORCH_CHECK(a_blockscale.dim() == 2,
"expected a_blockscale to be of shape [num_experts, rounded_m,"
" k // group_size], observed rank: ",
a_blockscale.dim())
TORCH_CHECK(b_blockscales.dim() == 3,
"expected b_blockscale to be of shape: "
" [num_experts, n, k // group_size], observed rank: ",
b_blockscales.dim())
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be a 2D tensor");
TORCH_CHECK(problem_sizes.size(1) == 3,
"problem_sizes must have the shape (num_experts, 3)");
TORCH_CHECK(problem_sizes.size(0) == expert_offsets.size(0),
"Number of experts in problem_sizes must match expert_offsets");
TORCH_CHECK(problem_sizes.dtype() == torch::kInt32,
"problem_sizes must be int32.");
STD_TORCH_CHECK(
a_blockscale.dim() == 2,
"expected a_blockscale to be of shape [num_experts, rounded_m,"
" k // group_size], observed rank: ",
a_blockscale.dim())
STD_TORCH_CHECK(b_blockscales.dim() == 3,
"expected b_blockscale to be of shape: "
" [num_experts, n, k // group_size], observed rank: ",
b_blockscales.dim())
STD_TORCH_CHECK(problem_sizes.dim() == 2,
"problem_sizes must be a 2D tensor");
STD_TORCH_CHECK(problem_sizes.size(1) == 3,
"problem_sizes must have the shape (num_experts, 3)");
STD_TORCH_CHECK(
problem_sizes.size(0) == expert_offsets.size(0),
"Number of experts in problem_sizes must match expert_offsets");
STD_TORCH_CHECK(
problem_sizes.scalar_type() == torch::headeronly::ScalarType::Int,
"problem_sizes must be int32.");
int M = static_cast<int>(a.size(0));
int N = static_cast<int>(b.size(1));
int E = static_cast<int>(b.size(0));
int K = static_cast<int>(2 * b.size(2));
if (output.scalar_type() == torch::kBFloat16) {
if (output.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
run_fp4_blockwise_scaled_group_mm<cutlass::bfloat16_t>(
output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes,
expert_offsets, sf_offsets, M, N, K);
@@ -633,7 +700,7 @@ void cutlass_fp4_group_mm(
#if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
int32_t version_num = get_sm_version_num();
if (version_num >= 120 && version_num < 130) {
TORCH_CHECK_NOT_IMPLEMENTED(
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false, "SM120 NVFP4 MOE only supports bfloat16 output, got: ",
output.scalar_type());
}
@@ -643,7 +710,7 @@ void cutlass_fp4_group_mm(
expert_offsets, sf_offsets, M, N, K);
}
#else
TORCH_CHECK_NOT_IMPLEMENTED(
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_fp4_group_mm kernel, vLLM must "
"be compiled with ENABLE_NVFP4_SM100 or ENABLE_NVFP4_SM120 for SM100/120 "
@@ -651,6 +718,6 @@ void cutlass_fp4_group_mm(
#endif
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("cutlass_fp4_group_mm", &cutlass_fp4_group_mm);
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) {
m.impl("cutlass_fp4_group_mm", TORCH_BOX(&cutlass_fp4_group_mm));
}

View File

@@ -14,16 +14,15 @@
* limitations under the License.
*/
#include <torch/all.h>
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "libtorch_stable/dispatch_utils.h"
#include "cuda_vec_utils.cuh"
#include <cuda_runtime_api.h>
#include <cuda_runtime.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_fp8.h>
#include "dispatch_utils.h"
#include "cuda_utils.h"
#include "nvfp4_utils.cuh"
@@ -327,25 +326,28 @@ void quant_impl(void* output, void* output_scale, void* input,
} // namespace vllm
/*Quantization entry for fp4 experts quantization*/
#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
#define CHECK_TH_CUDA(x, m) \
STD_TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x, m) \
TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
STD_TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
#define CHECK_INPUT(x, m) \
CHECK_TH_CUDA(x, m); \
CHECK_CONTIGUOUS(x, m);
constexpr auto HALF = at::ScalarType::Half;
constexpr auto BF16 = at::ScalarType::BFloat16;
constexpr auto FLOAT = at::ScalarType::Float;
constexpr auto INT = at::ScalarType::Int;
constexpr auto UINT8 = at::ScalarType::Byte;
constexpr auto HALF = torch::headeronly::ScalarType::Half;
constexpr auto BF16 = torch::headeronly::ScalarType::BFloat16;
constexpr auto FLOAT = torch::headeronly::ScalarType::Float;
constexpr auto INT = torch::headeronly::ScalarType::Int;
constexpr auto UINT8 = torch::headeronly::ScalarType::Byte;
// Common validation for fp4 experts quantization entry points.
static void validate_fp4_experts_quant_inputs(
torch::Tensor const& output, torch::Tensor const& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts, int64_t m_topk,
torch::stable::Tensor const& output,
torch::stable::Tensor const& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_global_scale,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts, int64_t m_topk,
int64_t k) {
CHECK_INPUT(output, "output");
CHECK_INPUT(output_scale, "output_scale");
@@ -354,41 +356,42 @@ static void validate_fp4_experts_quant_inputs(
CHECK_INPUT(input_offset_by_experts, "input_offset_by_experts");
CHECK_INPUT(output_scale_offset_by_experts, "output_scale_offset_by_experts");
TORCH_CHECK(output.dim() == 2);
TORCH_CHECK(output_scale.dim() == 2);
TORCH_CHECK(input.dim() == 2);
TORCH_CHECK(input_global_scale.dim() == 1);
TORCH_CHECK(input_offset_by_experts.dim() == 1);
TORCH_CHECK(output_scale_offset_by_experts.dim() == 1);
STD_TORCH_CHECK(output.dim() == 2);
STD_TORCH_CHECK(output_scale.dim() == 2);
STD_TORCH_CHECK(input.dim() == 2);
STD_TORCH_CHECK(input_global_scale.dim() == 1);
STD_TORCH_CHECK(input_offset_by_experts.dim() == 1);
STD_TORCH_CHECK(output_scale_offset_by_experts.dim() == 1);
TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16);
TORCH_CHECK(input_global_scale.scalar_type() == FLOAT);
TORCH_CHECK(input_offset_by_experts.scalar_type() == INT);
TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT);
STD_TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16);
STD_TORCH_CHECK(input_global_scale.scalar_type() == FLOAT);
STD_TORCH_CHECK(input_offset_by_experts.scalar_type() == INT);
STD_TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT);
// output is uint8 (two nvfp4 values are packed into one uint8)
// output_scale is int32 (four fp8 values are packed into one int32)
TORCH_CHECK(output.scalar_type() == UINT8);
TORCH_CHECK(output_scale.scalar_type() == INT);
STD_TORCH_CHECK(output.scalar_type() == UINT8);
STD_TORCH_CHECK(output_scale.scalar_type() == INT);
const int BLOCK_SIZE = 16;
TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16");
STD_TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16");
auto n_experts = input_global_scale.size(0);
TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1);
TORCH_CHECK(output_scale_offset_by_experts.size(0) == n_experts + 1);
TORCH_CHECK(output.size(0) == m_topk);
TORCH_CHECK(output.size(1) == k / 2);
STD_TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1);
STD_TORCH_CHECK(output_scale_offset_by_experts.size(0) == n_experts + 1);
STD_TORCH_CHECK(output.size(0) == m_topk);
STD_TORCH_CHECK(output.size(1) == k / 2);
int scales_k = k / BLOCK_SIZE;
// 4 means the swizzle requirement by nvidia nvfp4.
int padded_k = (scales_k + (4 - 1)) / 4 * 4;
// 4 means 4 fp8 values are packed into one int32
TORCH_CHECK(output_scale.size(1) * 4 == padded_k);
STD_TORCH_CHECK(output_scale.size(1) * 4 == padded_k);
}
void scaled_fp4_experts_quant_sm1xxa(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts) {
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_global_scale,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts) {
auto m_topk = input.size(0);
auto k = input.size(1);
@@ -397,11 +400,11 @@ void scaled_fp4_experts_quant_sm1xxa(
output_scale_offset_by_experts, m_topk, k);
auto n_experts = input_global_scale.size(0);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream =
at::cuda::getCurrentCUDAStream(input.get_device());
const torch::stable::accelerator::DeviceGuard device_guard(
input.get_device_index());
const cudaStream_t stream = get_current_cuda_stream(input.get_device_index());
VLLM_DISPATCH_HALF_TYPES(
VLLM_STABLE_DISPATCH_HALF_TYPES(
input.scalar_type(), "nvfp4_experts_quant_kernel", [&] {
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
vllm::quant_impl<cuda_type, /*FUSE_SILU_MUL=*/false>(
@@ -413,14 +416,15 @@ void scaled_fp4_experts_quant_sm1xxa(
}
void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts) {
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_global_scale,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts) {
auto m_topk = input.size(0);
// Input has gate || up layout, so k = input.size(1) / 2
auto k_times_2 = input.size(1);
TORCH_CHECK(k_times_2 % 2 == 0, "input width must be even (gate || up)");
STD_TORCH_CHECK(k_times_2 % 2 == 0, "input width must be even (gate || up)");
auto k = k_times_2 / 2;
validate_fp4_experts_quant_inputs(output, output_scale, input,
@@ -428,11 +432,11 @@ void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
output_scale_offset_by_experts, m_topk, k);
auto n_experts = input_global_scale.size(0);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream =
at::cuda::getCurrentCUDAStream(input.get_device());
const torch::stable::accelerator::DeviceGuard device_guard(
input.get_device_index());
const cudaStream_t stream = get_current_cuda_stream(input.get_device_index());
VLLM_DISPATCH_HALF_TYPES(
VLLM_STABLE_DISPATCH_HALF_TYPES(
input.scalar_type(), "silu_mul_nvfp4_experts_quant_kernel", [&] {
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
vllm::quant_impl<cuda_type, /*FUSE_SILU_MUL=*/true>(

View File

@@ -0,0 +1,175 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "cutlass_extensions/common.hpp"
#include "nvfp4_utils.cuh"
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
void scaled_fp4_quant_sm1xxa(torch::stable::Tensor const& output,
torch::stable::Tensor const& input,
torch::stable::Tensor const& output_sf,
torch::stable::Tensor const& input_sf,
bool is_sf_swizzled_layout);
#endif
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
void scaled_fp4_experts_quant_sm1xxa(
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_global_scale,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts);
#endif
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
void silu_and_mul_nvfp4_quant_sm1xxa(torch::stable::Tensor& output,
torch::stable::Tensor& output_sf,
torch::stable::Tensor& input,
torch::stable::Tensor& input_sf);
#endif
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_global_scale,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts);
#endif
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
static bool nvfp4_quant_sm_supported() {
const int32_t sm = get_sm_version_num();
#if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100
if (sm >= 100 && sm < 120) return true;
#endif
#if defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120
if (sm >= 120 && sm < 130) return true;
#endif
return false;
}
#endif
void scaled_fp4_quant_out(torch::stable::Tensor const& input,
torch::stable::Tensor const& input_sf,
bool is_sf_swizzled_layout,
torch::stable::Tensor& output,
torch::stable::Tensor& output_sf) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
STD_TORCH_CHECK(nvfp4_quant_sm_supported(),
"No compiled nvfp4 quantization kernel for SM ",
get_sm_version_num(),
". Recompile with the appropriate CUDA arch.");
return scaled_fp4_quant_sm1xxa(output, input, output_sf, input_sf,
is_sf_swizzled_layout);
#endif
STD_TORCH_CHECK_NOT_IMPLEMENTED(false,
"No compiled nvfp4 quantization kernel");
}
std::tuple<torch::stable::Tensor, torch::stable::Tensor> scaled_fp4_quant_func(
torch::stable::Tensor const& input, torch::stable::Tensor const& input_sf,
bool is_sf_swizzled_layout) {
int64_t n = input.size(-1);
int64_t m = input.numel() / n;
auto device = input.device();
// Two fp4 values packed into a uint8
auto output = torch::stable::empty(
{m, n / 2}, torch::headeronly::ScalarType::Byte, std::nullopt, device);
torch::stable::Tensor output_sf;
if (is_sf_swizzled_layout) {
auto [sf_m, sf_n] = vllm::computeSwizzledSFShape(m, n);
output_sf = torch::stable::empty(
{sf_m, sf_n}, torch::headeronly::ScalarType::Int, std::nullopt, device);
} else {
output_sf = torch::stable::empty({m, n / CVT_FP4_SF_VEC_SIZE},
torch::headeronly::ScalarType::Byte,
std::nullopt, device);
}
scaled_fp4_quant_out(input, input_sf, is_sf_swizzled_layout, output,
output_sf);
return {output, output_sf};
}
void scaled_fp4_experts_quant(
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_global_scale,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
STD_TORCH_CHECK(nvfp4_quant_sm_supported(),
"No compiled nvfp4 experts quantization kernel for SM ",
get_sm_version_num(),
". Recompile with the appropriate CUDA arch.");
return scaled_fp4_experts_quant_sm1xxa(
output, output_scale, input, input_global_scale, input_offset_by_experts,
output_scale_offset_by_experts);
#endif
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false, "No compiled nvfp4 experts quantization kernel");
}
void silu_and_mul_nvfp4_quant(torch::stable::Tensor& output,
torch::stable::Tensor& output_sf,
torch::stable::Tensor& input,
torch::stable::Tensor& input_sf) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
STD_TORCH_CHECK(nvfp4_quant_sm_supported(),
"No compiled silu_and_mul nvfp4 quantization kernel for SM ",
get_sm_version_num(),
". Recompile with the appropriate CUDA arch.");
return silu_and_mul_nvfp4_quant_sm1xxa(output, output_sf, input, input_sf);
#endif
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false, "No compiled silu_and_mul nvfp4 quantization kernel");
}
void silu_and_mul_scaled_fp4_experts_quant(
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_global_scale,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
STD_TORCH_CHECK(nvfp4_quant_sm_supported(),
"No compiled silu_and_mul nvfp4 experts quantization kernel "
"for SM ",
get_sm_version_num(),
". Recompile with the appropriate CUDA arch.");
return silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
output, output_scale, input, input_global_scale, input_offset_by_experts,
output_scale_offset_by_experts);
#endif
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false, "No compiled silu_and_mul nvfp4 experts quantization kernel");
}

View File

@@ -14,16 +14,16 @@
* limitations under the License.
*/
#include <torch/all.h>
#include <torch/csrc/stable/tensor.h>
#include <cuda_runtime_api.h>
#include <cuda_runtime.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_fp8.h>
#include "dispatch_utils.h"
#include "libtorch_stable/torch_utils.h"
#include "libtorch_stable/dispatch_utils.h"
#include "cuda_vec_utils.cuh"
#include "cuda_utils.h"
#include "launch_bounds_utils.h"
@@ -173,18 +173,19 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
} // namespace vllm
void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
torch::Tensor const& input,
torch::Tensor const& output_sf,
torch::Tensor const& input_sf,
void scaled_fp4_quant_sm1xxa(torch::stable::Tensor const& output,
torch::stable::Tensor const& input,
torch::stable::Tensor const& output_sf,
torch::stable::Tensor const& input_sf,
bool is_sf_swizzled_layout) {
int32_t m = input.size(0);
int32_t n = input.size(1);
TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");
TORCH_CHECK(input.scalar_type() == at::ScalarType::Half ||
input.scalar_type() == at::ScalarType::BFloat16,
"Unsupported input data type for quantize_to_fp4.");
STD_TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");
STD_TORCH_CHECK(
input.scalar_type() == torch::headeronly::ScalarType::Half ||
input.scalar_type() == torch::headeronly::ScalarType::BFloat16,
"Unsupported input data type for quantize_to_fp4.");
int multiProcessorCount =
get_device_attribute(cudaDevAttrMultiProcessorCount, -1);
@@ -192,8 +193,9 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
auto input_sf_ptr = static_cast<float const*>(input_sf.data_ptr());
auto sf_out = static_cast<int32_t*>(output_sf.data_ptr());
auto output_ptr = static_cast<int64_t*>(output.data_ptr());
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
const torch::stable::accelerator::DeviceGuard device_guard(
input.get_device_index());
auto stream = get_current_cuda_stream(input.get_device_index());
int sf_n_unpadded = int(n / CVT_FP4_SF_VEC_SIZE);
@@ -213,15 +215,15 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y));
dim3 grid(grid_x, grid_y);
VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] {
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
// NOTE: We don't support e8m0 scales at this moment.
vllm::cvt_fp16_to_fp4<cuda_type, false><<<grid, block, 0, stream>>>(
m, n, num_padded_cols, input_ptr, input_sf_ptr,
reinterpret_cast<uint32_t*>(output_ptr),
reinterpret_cast<uint32_t*>(sf_out));
});
VLLM_STABLE_DISPATCH_HALF_TYPES(
input.scalar_type(), "nvfp4_quant_kernel", [&] {
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
vllm::cvt_fp16_to_fp4<cuda_type, false><<<grid, block, 0, stream>>>(
m, n, num_padded_cols, input_ptr, input_sf_ptr,
reinterpret_cast<uint32_t*>(output_ptr),
reinterpret_cast<uint32_t*>(sf_out));
});
} else {
int num_packed_cols = n / CVT_FP4_ELTS_PER_THREAD;
int grid_y = vllm::div_round_up(num_packed_cols, static_cast<int>(block.x));
@@ -229,15 +231,15 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
m, std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y));
dim3 grid(grid_x, grid_y);
VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] {
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
// NOTE: We don't support e8m0 scales at this moment.
vllm::cvt_fp16_to_fp4_sf_major<cuda_type, false>
<<<grid, block, 0, stream>>>(m, n, sf_n_unpadded, num_packed_cols,
input_ptr, input_sf_ptr,
reinterpret_cast<uint32_t*>(output_ptr),
reinterpret_cast<uint32_t*>(sf_out));
});
VLLM_STABLE_DISPATCH_HALF_TYPES(
input.scalar_type(), "nvfp4_quant_kernel", [&] {
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
vllm::cvt_fp16_to_fp4_sf_major<cuda_type, false>
<<<grid, block, 0, stream>>>(
m, n, sf_n_unpadded, num_packed_cols, input_ptr, input_sf_ptr,
reinterpret_cast<uint32_t*>(output_ptr),
reinterpret_cast<uint32_t*>(sf_out));
});
}
}

View File

@@ -14,32 +14,39 @@
* limitations under the License.
*/
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "cutlass_extensions/common.hpp"
#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
torch::Tensor const& B,
torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha);
void cutlass_scaled_fp4_mm_sm100a(torch::stable::Tensor& D,
torch::stable::Tensor const& A,
torch::stable::Tensor const& B,
torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha);
#endif
#if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
void cutlass_scaled_fp4_mm_sm120a(torch::Tensor& D, torch::Tensor const& A,
torch::Tensor const& B,
torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha);
void cutlass_scaled_fp4_mm_sm120a(torch::stable::Tensor& D,
torch::stable::Tensor const& A,
torch::stable::Tensor const& B,
torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha);
#endif
void cutlass_scaled_fp4_mm(torch::Tensor& D, const torch::Tensor& A,
const torch::Tensor& B, const torch::Tensor& A_sf,
const torch::Tensor& B_sf,
const torch::Tensor& alpha) {
// Make sure were on As device.
const c10::cuda::OptionalCUDAGuard device_guard(device_of(A));
void cutlass_scaled_fp4_mm(torch::stable::Tensor& D,
const torch::stable::Tensor& A,
const torch::stable::Tensor& B,
const torch::stable::Tensor& A_sf,
const torch::stable::Tensor& B_sf,
const torch::stable::Tensor& alpha) {
// Make sure we're on A's device.
const torch::stable::accelerator::DeviceGuard device_guard(
A.get_device_index());
const int32_t sm = get_sm_version_num();
#if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100
@@ -56,8 +63,9 @@ void cutlass_scaled_fp4_mm(torch::Tensor& D, const torch::Tensor& A,
}
#endif
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 mm kernel for SM ", sm,
". Recompile with CUDA >= 12.8 and CC >= 100.");
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false, "No compiled nvfp4 mm kernel for SM ", sm,
". Recompile with CUDA >= 12.8 and CC >= 100.");
}
bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability) {

View File

@@ -14,10 +14,9 @@
* limitations under the License.
*/
#include <torch/all.h>
#include <torch/csrc/stable/tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "libtorch_stable/torch_utils.h"
#include "cutlass_extensions/common.hpp"
@@ -127,8 +126,9 @@ struct Fp4GemmSm100 {
template <typename Config>
typename Config::Gemm::Arguments args_from_options(
at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
at::Tensor const& A_sf, at::Tensor const& B_sf, at::Tensor const& alpha,
torch::stable::Tensor& D, torch::stable::Tensor const& A,
torch::stable::Tensor const& B, torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf, torch::stable::Tensor const& alpha,
int64_t M, int64_t N, int64_t K) {
using ElementA = typename Config::Gemm::ElementA;
using ElementB = typename Config::Gemm::ElementB;
@@ -174,19 +174,20 @@ typename Config::Gemm::Arguments args_from_options(
}
template <typename Config>
void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
at::Tensor const& A_sf, at::Tensor const& B_sf,
at::Tensor const& alpha, int64_t m, int64_t n, int64_t k,
cudaStream_t stream) {
void runGemm(torch::stable::Tensor& D, torch::stable::Tensor const& A,
torch::stable::Tensor const& B, torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha, int64_t m, int64_t n,
int64_t k, cudaStream_t stream) {
typename Config::Gemm gemm;
auto arguments =
args_from_options<Config>(D, A, B, A_sf, B_sf, alpha, m, n, k);
size_t workspace_size = Config::Gemm::get_workspace_size(arguments);
auto const workspace_options =
torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
auto workspace = torch::empty(workspace_size, workspace_options);
auto workspace =
torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
std::nullopt, A.device());
CUTLASS_CHECK(gemm.can_implement(arguments));
@@ -197,12 +198,13 @@ void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
// Dispatch function to select appropriate config based on M
template <typename OutType>
void cutlass_fp4_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
torch::Tensor const& B,
torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha, int64_t m, int64_t n,
int64_t k, cudaStream_t stream) {
void cutlass_fp4_gemm_dispatch(torch::stable::Tensor& D,
torch::stable::Tensor const& A,
torch::stable::Tensor const& B,
torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha, int64_t m,
int64_t n, int64_t k, cudaStream_t stream) {
uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
if (mp2 <= 16) {
@@ -222,61 +224,65 @@ void cutlass_fp4_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
#else
template <typename OutType>
void cutlass_fp4_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
torch::Tensor const& B,
torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha, int64_t m, int64_t n,
int64_t k, cudaStream_t stream) {
TORCH_CHECK(false,
"Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
"a CUTLASS 3.8 source directory to enable support.");
void cutlass_fp4_gemm_dispatch(torch::stable::Tensor& D,
torch::stable::Tensor const& A,
torch::stable::Tensor const& B,
torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha, int64_t m,
int64_t n, int64_t k, cudaStream_t stream) {
STD_TORCH_CHECK(false,
"Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
"a CUTLASS 3.8 source directory to enable support.");
}
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
#define CHECK_TYPE(x, st, m) \
TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m)
#define CHECK_TYPE(x, st, m) \
STD_TORCH_CHECK(x.scalar_type() == st, \
": Inconsistency of torch::stable::Tensor type:", m)
#define CHECK_TH_CUDA(x, m) \
TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor")
STD_TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x, m) \
TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous")
STD_TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous")
#define CHECK_INPUT(x, st, m) \
CHECK_TH_CUDA(x, m); \
CHECK_CONTIGUOUS(x, m); \
CHECK_TYPE(x, st, m)
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
constexpr auto FLOAT4_E2M1X2 = torch::headeronly::ScalarType::Byte;
constexpr auto SF_DTYPE = torch::headeronly::ScalarType::Float8_e4m3fn;
void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
torch::Tensor const& B,
torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha) {
void cutlass_scaled_fp4_mm_sm100a(torch::stable::Tensor& D,
torch::stable::Tensor const& A,
torch::stable::Tensor const& B,
torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha) {
CHECK_INPUT(A, FLOAT4_E2M1X2, "a");
CHECK_INPUT(B, FLOAT4_E2M1X2, "b");
CHECK_INPUT(A_sf, SF_DTYPE, "scale_a");
CHECK_INPUT(B_sf, SF_DTYPE, "scale_b");
CHECK_INPUT(alpha, at::ScalarType::Float, "alpha");
CHECK_INPUT(alpha, torch::headeronly::ScalarType::Float, "alpha");
TORCH_CHECK(A.dim() == 2, "a must be a matrix");
TORCH_CHECK(B.dim() == 2, "b must be a matrix");
TORCH_CHECK(A.sizes()[1] == B.sizes()[1],
"a and b shapes cannot be multiplied (", A.sizes()[0], "x",
A.sizes()[1], " and ", B.sizes()[0], "x", B.sizes()[1], ")");
STD_TORCH_CHECK(A.dim() == 2, "a must be a matrix");
STD_TORCH_CHECK(B.dim() == 2, "b must be a matrix");
STD_TORCH_CHECK(A.size(1) == B.size(1),
"a and b shapes cannot be multiplied (", A.size(0), "x",
A.size(1), " and ", B.size(0), "x", B.size(1), ")");
auto const m = A.sizes()[0];
auto const n = B.sizes()[0];
auto const k = A.sizes()[1] * 2;
auto const m = A.size(0);
auto const n = B.size(0);
auto const k = A.size(1) * 2;
constexpr int alignment = 32;
TORCH_CHECK(k % alignment == 0, "Expected k to be divisible by ", alignment,
", but got a shape: (", A.sizes()[0], "x", A.sizes()[1],
"), k: ", k, ".");
TORCH_CHECK(n % alignment == 0, "Expected n to be divisible by ", alignment,
", but got b shape: (", B.sizes()[0], "x", B.sizes()[1], ").");
STD_TORCH_CHECK(k % alignment == 0, "Expected k to be divisible by ",
alignment, ", but got a shape: (", A.size(0), "x", A.size(1),
"), k: ", k, ".");
STD_TORCH_CHECK(n % alignment == 0, "Expected n to be divisible by ",
alignment, ", but got b shape: (", B.size(0), "x", B.size(1),
").");
auto round_up = [](int x, int y) { return (x + y - 1) / y * y; };
int rounded_m = round_up(m, 128);
@@ -285,33 +291,34 @@ void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
// integer.
int rounded_k = round_up(k / 16, 4);
TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix");
TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix");
TORCH_CHECK(A_sf.sizes()[1] == B_sf.sizes()[1],
"scale_a and scale_b shapes cannot be multiplied (",
A_sf.sizes()[0], "x", A_sf.sizes()[1], " and ", B_sf.sizes()[0],
"x", B_sf.sizes()[1], ")");
TORCH_CHECK(A_sf.sizes()[0] == rounded_m && A_sf.sizes()[1] == rounded_k,
"scale_a must be padded and swizzled to a shape (", rounded_m,
"x", rounded_k, "), but got a shape (", A_sf.sizes()[0], "x",
A_sf.sizes()[1], ")");
TORCH_CHECK(B_sf.sizes()[0] == rounded_n && B_sf.sizes()[1] == rounded_k,
"scale_b must be padded and swizzled to a shape (", rounded_n,
"x", rounded_k, "), but got a shape (", B_sf.sizes()[0], "x",
B_sf.sizes()[1], ")");
STD_TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix");
STD_TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix");
STD_TORCH_CHECK(A_sf.size(1) == B_sf.size(1),
"scale_a and scale_b shapes cannot be multiplied (",
A_sf.size(0), "x", A_sf.size(1), " and ", B_sf.size(0), "x",
B_sf.size(1), ")");
STD_TORCH_CHECK(A_sf.size(0) == rounded_m && A_sf.size(1) == rounded_k,
"scale_a must be padded and swizzled to a shape (", rounded_m,
"x", rounded_k, "), but got a shape (", A_sf.size(0), "x",
A_sf.size(1), ")");
STD_TORCH_CHECK(B_sf.size(0) == rounded_n && B_sf.size(1) == rounded_k,
"scale_b must be padded and swizzled to a shape (", rounded_n,
"x", rounded_k, "), but got a shape (", B_sf.size(0), "x",
B_sf.size(1), ")");
auto out_dtype = D.dtype();
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device());
auto out_dtype = D.scalar_type();
const torch::stable::accelerator::DeviceGuard device_guard(
A.get_device_index());
const cudaStream_t stream = get_current_cuda_stream(A.get_device_index());
if (out_dtype == at::ScalarType::Half) {
if (out_dtype == torch::headeronly::ScalarType::Half) {
cutlass_fp4_gemm_dispatch<cutlass::half_t>(D, A, B, A_sf, B_sf, alpha, m, n,
k, stream);
} else if (out_dtype == at::ScalarType::BFloat16) {
} else if (out_dtype == torch::headeronly::ScalarType::BFloat16) {
cutlass_fp4_gemm_dispatch<cutlass::bfloat16_t>(D, A, B, A_sf, B_sf, alpha,
m, n, k, stream);
} else {
TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm (", out_dtype,
")");
STD_TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm (",
out_dtype, ")");
}
}

View File

@@ -14,10 +14,9 @@
* limitations under the License.
*/
#include <torch/all.h>
#include <torch/csrc/stable/tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "libtorch_stable/torch_utils.h"
#include "cutlass_extensions/common.hpp"
@@ -34,19 +33,20 @@
using namespace cute;
#define CHECK_TYPE(x, st, m) \
TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m)
#define CHECK_TYPE(x, st, m) \
STD_TORCH_CHECK(x.scalar_type() == st, \
": Inconsistency of torch::stable::Tensor type:", m)
#define CHECK_TH_CUDA(x, m) \
TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor")
STD_TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x, m) \
TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous")
STD_TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous")
#define CHECK_INPUT(x, st, m) \
CHECK_TH_CUDA(x, m); \
CHECK_CONTIGUOUS(x, m); \
CHECK_TYPE(x, st, m)
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
constexpr auto FLOAT4_E2M1X2 = torch::headeronly::ScalarType::Byte;
constexpr auto SF_DTYPE = torch::headeronly::ScalarType::Float8_e4m3fn;
struct sm120_fp4_config_M256 {
using ClusterShape = Shape<_1, _1, _1>;
@@ -109,12 +109,13 @@ struct Fp4GemmSm120 {
};
template <typename Gemm>
typename Gemm::Arguments args_from_options(at::Tensor& D, at::Tensor const& A,
at::Tensor const& B,
at::Tensor const& A_sf,
at::Tensor const& B_sf,
torch::Tensor const& alpha, int M,
int N, int K) {
typename Gemm::Arguments args_from_options(torch::stable::Tensor& D,
torch::stable::Tensor const& A,
torch::stable::Tensor const& B,
torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha,
int M, int N, int K) {
using ElementA = typename Gemm::ElementA;
using ElementB = typename Gemm::ElementB;
using ElementD = typename Gemm::ElementD;
@@ -158,18 +159,19 @@ typename Gemm::Arguments args_from_options(at::Tensor& D, at::Tensor const& A,
}
template <typename Gemm>
void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
at::Tensor const& A_sf, at::Tensor const& B_sf,
torch::Tensor const& alpha, int M, int N, int K,
void runGemm(torch::stable::Tensor& D, torch::stable::Tensor const& A,
torch::stable::Tensor const& B, torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha, int M, int N, int K,
cudaStream_t stream) {
Gemm gemm;
auto arguments = args_from_options<Gemm>(D, A, B, A_sf, B_sf, alpha, M, N, K);
size_t workspace_size = Gemm::get_workspace_size(arguments);
auto const workspace_options =
torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
auto workspace = torch::empty(workspace_size, workspace_options);
auto workspace =
torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
std::nullopt, A.device());
CUTLASS_CHECK(gemm.can_implement(arguments));
@@ -178,12 +180,13 @@ void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream));
}
void cutlass_fp4_bf16_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
torch::Tensor const& B,
torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha, int m, int n,
int k, cudaStream_t stream) {
void cutlass_fp4_bf16_gemm_dispatch(torch::stable::Tensor& D,
torch::stable::Tensor const& A,
torch::stable::Tensor const& B,
torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha, int m,
int n, int k, cudaStream_t stream) {
uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
if (mp2 <= 256) {
runGemm<Fp4GemmSm120<sm120_fp4_config_M256, cutlass::bfloat16_t>::Gemm>(
@@ -194,12 +197,13 @@ void cutlass_fp4_bf16_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
}
}
void cutlass_fp4_f16_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
torch::Tensor const& B,
torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha, int m, int n,
int k, cudaStream_t stream) {
void cutlass_fp4_f16_gemm_dispatch(torch::stable::Tensor& D,
torch::stable::Tensor const& A,
torch::stable::Tensor const& B,
torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha, int m,
int n, int k, cudaStream_t stream) {
uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
if (mp2 <= 256) {
runGemm<Fp4GemmSm120<sm120_fp4_config_M256, cutlass::half_t>::Gemm>(
@@ -210,11 +214,12 @@ void cutlass_fp4_f16_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
}
}
void cutlass_scaled_fp4_mm_sm120a(torch::Tensor& D, torch::Tensor const& A,
torch::Tensor const& B,
torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha) {
void cutlass_scaled_fp4_mm_sm120a(torch::stable::Tensor& D,
torch::stable::Tensor const& A,
torch::stable::Tensor const& B,
torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha) {
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
CHECK_INPUT(A, FLOAT4_E2M1X2, "a");
CHECK_INPUT(B, FLOAT4_E2M1X2, "b");
@@ -222,24 +227,25 @@ void cutlass_scaled_fp4_mm_sm120a(torch::Tensor& D, torch::Tensor const& A,
CHECK_INPUT(A_sf, SF_DTYPE, "scale_a");
CHECK_INPUT(B_sf, SF_DTYPE, "scale_b");
CHECK_INPUT(alpha, at::ScalarType::Float, "alpha");
CHECK_INPUT(alpha, torch::headeronly::ScalarType::Float, "alpha");
TORCH_CHECK(A.dim() == 2, "a must be a matrix");
TORCH_CHECK(B.dim() == 2, "b must be a matrix");
TORCH_CHECK(A.sizes()[1] == B.sizes()[1],
"a and b shapes cannot be multiplied (", A.sizes()[0], "x",
A.sizes()[1], " and ", B.sizes()[0], "x", B.sizes()[1], ")");
STD_TORCH_CHECK(A.dim() == 2, "a must be a matrix");
STD_TORCH_CHECK(B.dim() == 2, "b must be a matrix");
STD_TORCH_CHECK(A.size(1) == B.size(1),
"a and b shapes cannot be multiplied (", A.size(0), "x",
A.size(1), " and ", B.size(0), "x", B.size(1), ")");
auto const m = A.sizes()[0];
auto const n = B.sizes()[0];
auto const k = A.sizes()[1] * 2;
auto const m = A.size(0);
auto const n = B.size(0);
auto const k = A.size(1) * 2;
constexpr int alignment = 32;
TORCH_CHECK(k % alignment == 0, "Expected k to be divisible by ", alignment,
", but got a shape: (", A.sizes()[0], "x", A.sizes()[1],
"), k: ", k, ".");
TORCH_CHECK(n % alignment == 0, "Expected n to be divisible by ", alignment,
", but got b shape: (", B.sizes()[0], "x", B.sizes()[1], ").");
STD_TORCH_CHECK(k % alignment == 0, "Expected k to be divisible by ",
alignment, ", but got a shape: (", A.size(0), "x", A.size(1),
"), k: ", k, ".");
STD_TORCH_CHECK(n % alignment == 0, "Expected n to be divisible by ",
alignment, ", but got b shape: (", B.size(0), "x", B.size(1),
").");
auto round_up = [](int x, int y) { return (x + y - 1) / y * y; };
int rounded_m = round_up(m, 128);
@@ -248,38 +254,39 @@ void cutlass_scaled_fp4_mm_sm120a(torch::Tensor& D, torch::Tensor const& A,
// integer.
int rounded_k = round_up(k / 16, 4);
TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix");
TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix");
TORCH_CHECK(A_sf.sizes()[1] == B_sf.sizes()[1],
"scale_a and scale_b shapes cannot be multiplied (",
A_sf.sizes()[0], "x", A_sf.sizes()[1], " and ", B_sf.sizes()[0],
"x", B_sf.sizes()[1], ")");
TORCH_CHECK(A_sf.sizes()[0] == rounded_m && A_sf.sizes()[1] == rounded_k,
"scale_a must be padded and swizzled to a shape (", rounded_m,
"x", rounded_k, "), but got a shape (", A_sf.sizes()[0], "x",
A_sf.sizes()[1], ")");
TORCH_CHECK(B_sf.sizes()[0] == rounded_n && B_sf.sizes()[1] == rounded_k,
"scale_b must be padded and swizzled to a shape (", rounded_n,
"x", rounded_k, "), but got a shape (", B_sf.sizes()[0], "x",
B_sf.sizes()[1], ")");
STD_TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix");
STD_TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix");
STD_TORCH_CHECK(A_sf.size(1) == B_sf.size(1),
"scale_a and scale_b shapes cannot be multiplied (",
A_sf.size(0), "x", A_sf.size(1), " and ", B_sf.size(0), "x",
B_sf.size(1), ")");
STD_TORCH_CHECK(A_sf.size(0) == rounded_m && A_sf.size(1) == rounded_k,
"scale_a must be padded and swizzled to a shape (", rounded_m,
"x", rounded_k, "), but got a shape (", A_sf.size(0), "x",
A_sf.size(1), ")");
STD_TORCH_CHECK(B_sf.size(0) == rounded_n && B_sf.size(1) == rounded_k,
"scale_b must be padded and swizzled to a shape (", rounded_n,
"x", rounded_k, "), but got a shape (", B_sf.size(0), "x",
B_sf.size(1), ")");
auto out_dtype = D.dtype();
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device());
auto out_dtype = D.scalar_type();
const torch::stable::accelerator::DeviceGuard device_guard(
A.get_device_index());
const cudaStream_t stream = get_current_cuda_stream(A.get_device_index());
if (out_dtype == at::ScalarType::BFloat16) {
if (out_dtype == torch::headeronly::ScalarType::BFloat16) {
return cutlass_fp4_bf16_gemm_dispatch(D, A, B, A_sf, B_sf, alpha, m, n, k,
stream);
} else if (out_dtype == at::ScalarType::Half) {
} else if (out_dtype == torch::headeronly::ScalarType::Half) {
return cutlass_fp4_f16_gemm_dispatch(D, A, B, A_sf, B_sf, alpha, m, n, k,
stream);
} else {
TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm sm120 (",
out_dtype, ")");
STD_TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm sm120 (",
out_dtype, ")");
}
#else
TORCH_CHECK(false,
"Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
"a CUTLASS 3.8 source directory to enable support.");
STD_TORCH_CHECK(false,
"Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
"a CUTLASS 3.8 source directory to enable support.");
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
}
}

View File

@@ -20,7 +20,7 @@
#include <cuda_fp8.h>
#include <utility>
#include "../../cuda_vec_utils.cuh"
#include "cuda_vec_utils.cuh"
#if defined(NVFP4_ENABLE_ELTS16) && defined(CUDA_VERSION) && \
CUDA_VERSION >= 12090

View File

@@ -103,6 +103,102 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
ops.def(
"cutlass_scaled_mm_supports_block_fp8(int cuda_device_capability) -> "
"bool");
// CUTLASS nvfp4 block scaled GEMM
ops.def(
"cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
" Tensor block_scale_a, Tensor block_scale_b,"
" Tensor alpha) -> ()");
// cutlass nvfp4 block scaled group GEMM
ops.def(
"cutlass_fp4_group_mm(Tensor! out, Tensor a, Tensor b,"
" Tensor a_blockscale, Tensor b_blockscales, Tensor alphas,"
" Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()");
// Compute NVFP4 block quantized tensor.
ops.def(
"scaled_fp4_quant(Tensor input,"
" Tensor input_scale, bool "
"is_sf_swizzled_layout) -> (Tensor, Tensor)");
// Out variant
// TODO: Add out_variant tag once PyTorch supports it (added in 2.11)
// This registration is now migrated to stable ABI
// at::Tag::out_variant is not available in the stable ABI (enum_tag.h is not
// yet in torch/headeronly), the tag should be applied from Python
// via torch.library.Library.define(..., tags=(torch.Tag.out_variant,))
// with the .impl remaining in C++.
// See pytorch/pytorch#176117.
ops.def(
"scaled_fp4_quant.out(Tensor input,"
" Tensor input_scale, bool "
"is_sf_swizzled_layout, *, Tensor(a!) output, Tensor(b!) output_scale) "
"-> ()");
// Compute NVFP4 experts quantization.
ops.def(
"scaled_fp4_experts_quant(Tensor! output, Tensor! output_scale,"
"Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
"Tensor output_scale_offset_by_experts) -> ()");
// Fused SiLU+Mul+NVFP4 experts quantization.
ops.def(
"silu_and_mul_scaled_fp4_experts_quant(Tensor! output, Tensor! "
"output_scale,"
"Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
"Tensor output_scale_offset_by_experts) -> ()");
// Fused SiLU+Mul+NVFP4 quantization.
ops.def(
"silu_and_mul_nvfp4_quant(Tensor! result, Tensor! result_block_scale, "
"Tensor input, Tensor input_global_scale) -> ()");
// Check if cutlass_scaled_mm_fp4 is supported for CUDA devices
// of the given capability
ops.def("cutlass_scaled_mm_supports_fp4(int cuda_device_capability) -> bool");
// CUTLASS w4a8 GEMM
ops.def(
"cutlass_w4a8_mm("
" Tensor A,"
" Tensor B,"
" Tensor group_scales,"
" int group_size,"
" Tensor channel_scales,"
" Tensor token_scales,"
" ScalarType? out_type,"
" str? maybe_schedule"
") -> Tensor");
// pack scales
ops.def("cutlass_pack_scale_fp8(Tensor scales) -> Tensor");
// encode and reorder weight matrix
ops.def("cutlass_encode_and_reorder_int4b(Tensor B) -> Tensor");
// CUTLASS w4a8 grouped GEMM
ops.def(
"cutlass_w4a8_moe_mm("
" Tensor! out_tensors,"
" Tensor a_tensors,"
" Tensor b_tensors,"
" Tensor a_scales,"
" Tensor b_scales,"
" Tensor b_group_scales,"
" int b_group_size,"
" Tensor expert_offsets,"
" Tensor problem_sizes,"
" Tensor a_strides,"
" Tensor b_strides,"
" Tensor c_strides,"
" Tensor group_scale_strides,"
" str? maybe_schedule"
") -> ()");
ops.def(
"cutlass_encode_and_reorder_int4b_grouped(Tensor b_tensors) -> (Tensor, "
"Tensor)");
#endif
}
@@ -128,6 +224,18 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
TORCH_BOX(&get_cutlass_moe_mm_problem_sizes_from_expert_offsets));
ops.impl("get_cutlass_batched_moe_mm_data",
TORCH_BOX(&get_cutlass_batched_moe_mm_data));
// FP4/NVFP4 ops
ops.impl("cutlass_scaled_fp4_mm", TORCH_BOX(&cutlass_scaled_fp4_mm));
ops.impl("scaled_fp4_quant", TORCH_BOX(&scaled_fp4_quant_func));
ops.impl("scaled_fp4_quant.out", TORCH_BOX(&scaled_fp4_quant_out));
ops.impl("scaled_fp4_experts_quant", TORCH_BOX(&scaled_fp4_experts_quant));
ops.impl("silu_and_mul_scaled_fp4_experts_quant",
TORCH_BOX(&silu_and_mul_scaled_fp4_experts_quant));
ops.impl("silu_and_mul_nvfp4_quant", TORCH_BOX(&silu_and_mul_nvfp4_quant));
// W4A8 ops: impl registrations are in the source files
// (w4a8_mm_entry.cu and w4a8_grouped_mm_entry.cu)
#endif
}
@@ -143,6 +251,8 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CompositeExplicitAutograd, ops) {
TORCH_BOX(&cutlass_group_gemm_supported));
ops.impl("cutlass_scaled_mm_supports_block_fp8",
TORCH_BOX(&cutlass_scaled_mm_supports_block_fp8));
ops.impl("cutlass_scaled_mm_supports_fp4",
TORCH_BOX(&cutlass_scaled_mm_supports_fp4));
#endif
}

View File

@@ -2,6 +2,7 @@
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/stable/accelerator.h>
#include <torch/csrc/stable/ops.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/util/shim_utils.h>

View File

@@ -1,144 +0,0 @@
/*
* Adapted from
* https://github.com/NVIDIA/TensorRT-LLM/blob/v1.3.0rc7/cpp/tensorrt_llm/kernels/tinygemm2/tinygemm2_cuda.cu
* Copyright (c) 2025, The vLLM team.
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
* All rights reserved. SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/all.h>
#include "gpt_oss_router_gemm.cuh"
void launch_gpt_oss_router_gemm(__nv_bfloat16* gA, __nv_bfloat16* gB,
__nv_bfloat16* gC, __nv_bfloat16* bias,
int batch_size, int output_features,
int input_features, cudaStream_t stream) {
static int const WARP_TILE_M = 16;
static int const TILE_M = WARP_TILE_M;
static int const TILE_N = 8;
static int const TILE_K = 64;
static int const STAGES = 16;
static int const STAGE_UNROLL = 4;
static bool const PROFILE = false;
CUtensorMap weight_map{};
CUtensorMap activation_map{};
constexpr uint32_t rank = 2;
uint64_t size[rank] = {(uint64_t)input_features, (uint64_t)output_features};
uint64_t stride[rank - 1] = {input_features * sizeof(__nv_bfloat16)};
uint32_t box_size[rank] = {TILE_K, TILE_M};
uint32_t elem_stride[rank] = {1, 1};
CUresult res = cuTensorMapEncodeTiled(
&weight_map, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, rank,
gB, size, stride, box_size, elem_stride,
CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B,
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE,
CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
TORCH_CHECK(res == CUDA_SUCCESS,
"cuTensorMapEncodeTiled failed for weight_map, error code=",
static_cast<int>(res));
size[1] = batch_size;
box_size[1] = TILE_N;
res = cuTensorMapEncodeTiled(
&activation_map, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
rank, gA, size, stride, box_size, elem_stride,
CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B,
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE,
CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
TORCH_CHECK(res == CUDA_SUCCESS,
"cuTensorMapEncodeTiled failed for activation_map, error code=",
static_cast<int>(res));
int smem_size = STAGES * STAGE_UNROLL *
(TILE_M * TILE_K * sizeof(__nv_bfloat16) +
TILE_N * TILE_K * sizeof(__nv_bfloat16));
gpuErrChk(cudaFuncSetAttribute(
gpt_oss_router_gemm_kernel<WARP_TILE_M, TILE_M, TILE_N, TILE_K, STAGES,
STAGE_UNROLL, PROFILE>,
cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
int tiles_m = (output_features + TILE_M - 1) / TILE_M;
int tiles_n = (batch_size + TILE_N - 1) / TILE_N;
dim3 grid(tiles_m, tiles_n);
dim3 block(384);
cudaLaunchConfig_t config;
cudaLaunchAttribute attrs[1];
config.gridDim = grid;
config.blockDim = block;
config.dynamicSmemBytes = smem_size;
config.stream = stream;
config.attrs = attrs;
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = 1;
config.numAttrs = 1;
cudaLaunchKernelEx(
&config,
&gpt_oss_router_gemm_kernel<WARP_TILE_M, TILE_M, TILE_N, TILE_K, STAGES,
STAGE_UNROLL, PROFILE>,
gC, gA, gB, bias, output_features, batch_size, input_features, weight_map,
activation_map, nullptr);
}
void gpt_oss_router_gemm_cuda_forward(torch::Tensor& output,
torch::Tensor input, torch::Tensor weight,
torch::Tensor bias) {
auto const batch_size = input.size(0);
auto const input_dim = input.size(1);
auto const output_dim = weight.size(0);
auto stream = at::cuda::getCurrentCUDAStream();
if (input.scalar_type() == at::ScalarType::BFloat16) {
launch_gpt_oss_router_gemm((__nv_bfloat16*)input.data_ptr(),
(__nv_bfloat16*)weight.data_ptr(),
(__nv_bfloat16*)output.mutable_data_ptr(),
(__nv_bfloat16*)bias.data_ptr(), batch_size,
output_dim, input_dim, stream);
} else {
throw std::invalid_argument("Unsupported dtype, only supports bfloat16");
}
}
void gpt_oss_router_gemm(torch::Tensor& output, torch::Tensor input,
torch::Tensor weight, torch::Tensor bias) {
TORCH_CHECK(input.dim() == 2, "input must be 2D");
TORCH_CHECK(weight.dim() == 2, "weight must be 2D");
TORCH_CHECK(bias.dim() == 1, "bias must be 1D");
TORCH_CHECK(input.sizes()[1] == weight.sizes()[1],
"input.size(1) must match weight.size(1)");
TORCH_CHECK(weight.sizes()[0] == bias.sizes()[0],
"weight.size(0) must match bias.size(0)");
TORCH_CHECK(input.scalar_type() == at::ScalarType::BFloat16,
"input tensor must be bfloat16");
TORCH_CHECK(weight.scalar_type() == at::ScalarType::BFloat16,
"weight tensor must be bfloat16");
TORCH_CHECK(bias.scalar_type() == at::ScalarType::BFloat16,
"bias tensor must be bfloat16");
gpt_oss_router_gemm_cuda_forward(output, input, weight, bias);
}

View File

@@ -1,447 +0,0 @@
/*
* Adapted from
* https://github.com/NVIDIA/TensorRT-LLM/blob/v1.3.0rc7/cpp/tensorrt_llm/kernels/tinygemm2/tinygemm2_kernel.cuh
* Copyright (c) 2025, The vLLM team.
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
* All rights reserved. SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "cuda_bf16.h"
#include <stdint.h>
#include <stdio.h>
#include <vector>
#include "cuda_pipeline.h"
#include <cuda.h>
#include <cuda/barrier>
#include <cuda/std/utility>
#include <cuda_runtime.h>
using barrier = cuda::barrier<cuda::thread_scope_block>;
namespace cde = cuda::device::experimental;
namespace ptx = cuda::ptx;
#define gpuErrChk(ans) \
{ \
gpuAssert((ans), __FILE__, __LINE__); \
}
inline void gpuAssert(cudaError_t code, char const* file, int line,
bool abort = true) {
if (code != cudaSuccess) {
fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file,
line);
if (abort) {
throw std::runtime_error(cudaGetErrorString(code));
}
}
}
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
__device__ uint64_t gclock64() {
unsigned long long int rv;
asm volatile("mov.u64 %0, %%globaltimer;" : "=l"(rv));
return rv;
}
__device__ void ldmatrix(__nv_bfloat16 rv[2], uint32_t smem_ptr) {
int dst;
asm volatile("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n"
: "=r"(dst)
: "r"(smem_ptr));
int* rvi = reinterpret_cast<int*>(&rv[0]);
rvi[0] = dst;
}
__device__ void ldmatrix2(__nv_bfloat16 rv[4], uint32_t smem_ptr) {
int x, y;
asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n"
: "=r"(x), "=r"(y)
: "r"(smem_ptr));
int* rvi = reinterpret_cast<int*>(&rv[0]);
rvi[0] = x;
rvi[1] = y;
}
__device__ void ldmatrix4(__nv_bfloat16 rv[8], uint32_t smem_ptr) {
int x, y, z, w;
asm volatile(
"ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];"
: "=r"(x), "=r"(y), "=r"(z), "=r"(w)
: "r"(smem_ptr));
int* rvi = reinterpret_cast<int*>(&rv[0]);
rvi[0] = x;
rvi[1] = y;
rvi[2] = z;
rvi[3] = w;
}
__device__ void HMMA_1688(float d[4], __nv_bfloat16 a[4], __nv_bfloat16 b[2],
float c[4]) {
uint32_t const* A = reinterpret_cast<uint32_t const*>(&a[0]);
uint32_t const* B = reinterpret_cast<uint32_t const*>(&b[0]);
float const* C = reinterpret_cast<float const*>(&c[0]);
float* D = reinterpret_cast<float*>(&d[0]);
asm volatile(
"mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
: "r"(A[0]), "r"(A[1]), "r"(B[0]), "f"(C[0]), "f"(C[1]), "f"(C[2]),
"f"(C[3]));
}
__device__ void HMMA_16816(float d[4], __nv_bfloat16 a[8], __nv_bfloat16 b[4],
float c[4]) {
uint32_t const* A = reinterpret_cast<uint32_t const*>(&a[0]);
uint32_t const* B = reinterpret_cast<uint32_t const*>(&b[0]);
float const* C = reinterpret_cast<float const*>(&c[0]);
float* D = reinterpret_cast<float*>(&d[0]);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
: "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
"f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
}
__device__ void bar_wait(uint32_t bar_ptr, int phase) {
asm volatile(
"{\n"
".reg .pred P1;\n"
"LAB_WAIT:\n"
"mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1;\n"
"@P1 bra.uni DONE;\n"
"bra.uni LAB_WAIT;\n"
"DONE:\n"
"}\n" ::"r"(bar_ptr),
"r"(phase));
}
__device__ bool bar_try_wait(uint32_t bar_ptr, int phase) {
uint32_t success;
#ifdef INTERNAL
asm volatile(".pragma \"set knob DontInsertYield\";\n" : : : "memory");
#endif
asm volatile(
"{\n\t"
".reg .pred P1; \n\t"
"mbarrier.try_wait.parity.shared::cta.b64 P1, [%1], %2; \n\t"
"selp.b32 %0, 1, 0, P1; \n\t"
"}"
: "=r"(success)
: "r"(bar_ptr), "r"(phase));
return success;
}
__device__ uint32_t elect_one_sync() {
uint32_t pred = 0;
uint32_t laneid = 0;
asm volatile(
"{\n"
".reg .b32 %%rx;\n"
".reg .pred %%px;\n"
" elect.sync %%rx|%%px, %2;\n"
"@%%px mov.s32 %1, 1;\n"
" mov.s32 %0, %%rx;\n"
"}\n"
: "+r"(laneid), "+r"(pred)
: "r"(0xFFFFFFFF));
return pred;
}
#endif
struct Profile {
uint64_t start;
uint64_t weight_load_start;
uint64_t act_load_start;
uint64_t compute_start;
uint64_t complete;
};
template <int WARP_TILE_M, int TILE_M, int TILE_N, int TILE_K, int STAGES,
int STAGE_UNROLL, bool PROFILE>
__global__ __launch_bounds__(384, 1) void gpt_oss_router_gemm_kernel(
__nv_bfloat16* output, __nv_bfloat16* weights, __nv_bfloat16* activations,
__nv_bfloat16* bias, int M, int N, int K,
const __grid_constant__ CUtensorMap weight_map,
const __grid_constant__ CUtensorMap activation_map,
Profile* profile = nullptr) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
if (PROFILE && threadIdx.x == 0 && blockIdx.y == 0)
profile[blockIdx.x].start = gclock64();
extern __shared__ __align__(128) char smem[];
__nv_bfloat16* sh_weights = (__nv_bfloat16*)&smem[0];
__nv_bfloat16* sh_activations =
(__nv_bfloat16*)&smem[STAGES * STAGE_UNROLL * TILE_M * TILE_K *
sizeof(__nv_bfloat16)];
#pragma nv_diag_suppress static_var_with_dynamic_init
__shared__ barrier bar_wt_ready[STAGES];
__shared__ barrier bar_act_ready[STAGES];
__shared__ barrier bar_data_consumed[STAGES];
__shared__ float4 reduction_buffer[128];
__shared__ nv_bfloat16 sh_bias[TILE_M];
if (threadIdx.x == 0) {
for (int i = 0; i < STAGES; i++) {
init(&bar_wt_ready[i], 1);
init(&bar_act_ready[i], 1);
init(&bar_data_consumed[i], 32);
}
ptx::fence_proxy_async(ptx::space_shared);
asm volatile("prefetch.tensormap [%0];"
:
: "l"(reinterpret_cast<uint64_t>(&weight_map))
: "memory");
asm volatile("prefetch.tensormap [%0];"
:
: "l"(reinterpret_cast<uint64_t>(&activation_map))
: "memory");
}
__syncthreads();
int warp_id = threadIdx.x / 32;
int lane_id = threadIdx.x % 32;
int phase = 0;
int mib = blockIdx.x * TILE_M;
int ni = blockIdx.y * TILE_N;
float accum[4];
for (int i = 0; i < 4; i++) accum[i] = 0.f;
int const K_LOOPS_DMA =
(K + 4 * TILE_K * STAGE_UNROLL - 1) / (4 * (TILE_K * STAGE_UNROLL));
int const K_LOOPS_COMPUTE = K_LOOPS_DMA;
// Data loading thread
if (warp_id >= 4 && elect_one_sync()) {
int stage = warp_id % 4;
bool weight_warp = warp_id < 8;
if (!weight_warp) {
cudaGridDependencySynchronize();
cudaTriggerProgrammaticLaunchCompletion();
}
for (int ki = 0; ki < K_LOOPS_DMA; ki++) {
int k = (ki * 4 + (warp_id % 4)) * TILE_K * STAGE_UNROLL;
uint64_t desc_ptr_wt = reinterpret_cast<uint64_t>(&weight_map);
uint64_t desc_ptr_act = reinterpret_cast<uint64_t>(&activation_map);
uint32_t bar_ptr_wt = __cvta_generic_to_shared(&bar_wt_ready[stage]);
uint32_t bar_ptr_act = __cvta_generic_to_shared(&bar_act_ready[stage]);
int bytes_wt = TILE_M * TILE_K * sizeof(__nv_bfloat16);
int bytes_act = TILE_N * TILE_K * sizeof(__nv_bfloat16);
bar_wait(__cvta_generic_to_shared(&bar_data_consumed[stage]), phase ^ 1);
if (weight_warp)
asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;"
:
: "r"(bar_ptr_wt), "r"(STAGE_UNROLL * bytes_wt));
if (!weight_warp)
asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;"
:
: "r"(bar_ptr_act), "r"(STAGE_UNROLL * bytes_act));
if (PROFILE && blockIdx.y == 0 && ki == 0 && weight_warp)
profile[blockIdx.x].weight_load_start = gclock64();
if (PROFILE && blockIdx.y == 0 && ki == 0 && !weight_warp)
profile[blockIdx.x].act_load_start = gclock64();
for (int i = 0; i < STAGE_UNROLL; i++) {
uint32_t smem_ptr_wt = __cvta_generic_to_shared(
&sh_weights[(stage * STAGE_UNROLL + i) * TILE_M * TILE_K]);
uint32_t crd0 = k + i * TILE_K;
uint32_t crd1 = mib;
if (weight_warp)
asm volatile(
"cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::complete_"
"tx::bytes [%0], [%1, {%3,%4}], "
"[%2];"
:
: "r"(smem_ptr_wt), "l"(desc_ptr_wt), "r"(bar_ptr_wt), "r"(crd0),
"r"(crd1)
: "memory");
uint32_t smem_ptr_act = __cvta_generic_to_shared(
&sh_activations[(stage * STAGE_UNROLL + i) * TILE_N * TILE_K]);
crd0 = k + i * TILE_K;
crd1 = ni;
if (!weight_warp)
asm volatile(
"cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::complete_"
"tx::bytes [%0], [%1, {%3,%4}], "
"[%2];"
:
: "r"(smem_ptr_act), "l"(desc_ptr_act), "r"(bar_ptr_act),
"r"(crd0), "r"(crd1)
: "memory");
}
stage += 4;
if (stage >= STAGES) {
stage = warp_id % 4;
phase ^= 1;
}
}
// Wait for pending loads to be consumed before exiting, to avoid race
for (int i = 0; i < (STAGES / 4) - 1; i++) {
bar_wait(__cvta_generic_to_shared(&bar_data_consumed[stage]), phase ^ 1);
stage += 4;
if (stage >= STAGES) {
stage = warp_id % 4;
phase ^= 1;
}
}
}
// Compute threads
else if (warp_id < 4) {
// Sneak the bias load into the compute warps since they're just waiting for
// stuff anyway
if (threadIdx.x < TILE_M) sh_bias[threadIdx.x] = bias[mib + threadIdx.x];
int stage = warp_id;
int phase = 0;
int lane_id_div8 = lane_id / 8;
int lane_id_mod8 = lane_id % 8;
int lane_row_offset_wt = (lane_id_div8 % 2) ? 8 : 0;
int lane_col_offset_wt = (lane_id_div8 / 2) ? 1 : 0;
int row_wt = lane_id_mod8 + lane_row_offset_wt;
int row_act = lane_id_mod8;
int row_offset_wt = (reinterpret_cast<uintptr_t>(sh_weights) / 128) % 8;
int row_offset_act = row_offset_wt;
uint32_t bar_ptr_wt = __cvta_generic_to_shared(&bar_wt_ready[stage]);
uint32_t bar_ptr_act = __cvta_generic_to_shared(&bar_act_ready[stage]);
bool weight_ready = bar_try_wait(bar_ptr_wt, phase);
bool act_ready = bar_try_wait(bar_ptr_act, phase);
#pragma unroll 2
for (int ki = 0; ki < K_LOOPS_COMPUTE; ki++) {
int next_stage = stage + 4;
int next_phase = phase;
if (next_stage >= STAGES) {
next_stage = warp_id;
next_phase ^= 1;
}
while (!weight_ready || !act_ready) {
weight_ready = bar_try_wait(bar_ptr_wt, phase);
act_ready = bar_try_wait(bar_ptr_act, phase);
}
if (PROFILE && blockIdx.y == 0 && threadIdx.x == 0 && ki == 0)
profile[blockIdx.x].compute_start = gclock64();
if (ki + 1 < K_LOOPS_COMPUTE) {
weight_ready = bar_try_wait(
__cvta_generic_to_shared(&bar_wt_ready[next_stage]), next_phase);
act_ready = bar_try_wait(
__cvta_generic_to_shared(&bar_act_ready[next_stage]), next_phase);
}
#pragma unroll
for (int su = 0; su < STAGE_UNROLL; su++) {
__nv_bfloat16* ptr_weights =
&sh_weights[(stage * STAGE_UNROLL + su) * TILE_M * TILE_K];
__nv_bfloat16* ptr_act =
&sh_activations[(stage * STAGE_UNROLL + su) * TILE_N * TILE_K];
#pragma unroll
for (int kii = 0; kii < TILE_K / 16; kii++) {
__nv_bfloat16 a[8];
__nv_bfloat16 b[4];
int col = 2 * kii + lane_col_offset_wt;
int col_sw = ((row_wt + row_offset_wt) % 8) ^ col;
ldmatrix4(a, __cvta_generic_to_shared(
&ptr_weights[row_wt * TILE_K + col_sw * 8]));
col = 2 * kii + lane_id_div8;
col_sw = ((row_act + row_offset_act) % 8) ^ col;
ldmatrix2(b, __cvta_generic_to_shared(
&ptr_act[row_act * TILE_K + 8 * col_sw]));
HMMA_16816(accum, a, b, accum);
}
}
uint32_t bar_c = __cvta_generic_to_shared(&bar_data_consumed[stage]);
asm volatile("mbarrier.arrive.shared::cta.b64 _, [%0];" : : "r"(bar_c));
stage = next_stage;
phase = next_phase;
}
float4 accum4;
accum4.x = accum[0];
accum4.y = accum[1];
accum4.z = accum[2];
accum4.w = accum[3];
reduction_buffer[threadIdx.x] = accum4;
__syncthreads();
if (warp_id == 0) {
int mi = mib + warp_id * WARP_TILE_M;
int tm = mi + lane_id / 4;
int tn = ni + 2 * (lane_id % 4);
float4 accum1 = reduction_buffer[32 + threadIdx.x];
float4 accum2 = reduction_buffer[64 + threadIdx.x];
float4 accum3 = reduction_buffer[96 + threadIdx.x];
accum[0] = accum[0] + accum1.x + accum2.x + accum3.x;
accum[1] = accum[1] + accum1.y + accum2.y + accum3.y;
accum[2] = accum[2] + accum1.z + accum2.z + accum3.z;
accum[3] = accum[3] + accum1.w + accum2.w + accum3.w;
float bias_lo = __bfloat162float(sh_bias[tm - mib]);
float bias_hi = __bfloat162float(sh_bias[tm + 8 - mib]);
if (tn < N && tm < M)
output[tn * M + tm] = __float2bfloat16(accum[0] + bias_lo);
if (tn + 1 < N && tm < M)
output[(tn + 1) * M + tm] = __float2bfloat16(accum[1] + bias_lo);
if (tn < N && tm + 8 < M)
output[tn * M + tm + 8] = __float2bfloat16(accum[2] + bias_hi);
if (tn + 1 < N && tm + 8 < M)
output[(tn + 1) * M + tm + 8] = __float2bfloat16(accum[3] + bias_hi);
if (PROFILE && blockIdx.y == 0 && threadIdx.x == 0)
profile[blockIdx.x].complete = gclock64();
}
}
#endif // end if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
}

View File

@@ -108,6 +108,15 @@ QUANT_CONFIGS = [
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [2],
},
# MXFP8
{
"a_type": ["kBFloat16"],
"b_type": "kFE4M3fn",
"s_type": "kFE8M0fnu",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [2],
},
# AWQ-INT4 with INT8 activation
{
"a_type": ["kS8"],

View File

@@ -343,6 +343,8 @@ __global__ void Marlin(
if constexpr (b_type == vllm::kFE2M1f) {
static_assert(s_type == vllm::kFE4M3fn && group_blocks == 1 ||
s_type == vllm::kFE8M0fnu && group_blocks == 2);
} else if constexpr (b_type == vllm::kFE4M3fn && s_type == vllm::kFE8M0fnu) {
static_assert(group_blocks == 2);
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
static_assert(s_type == vllm::kBFloat16);
} else if constexpr (std::is_same<scalar_t, half>::value) {
@@ -357,9 +359,10 @@ __global__ void Marlin(
constexpr bool is_int_type = b_type == vllm::kU4 || b_type == vllm::kU8 ||
b_type == vllm::kS4 || b_type == vllm::kS8 ||
b_type == vllm::kU4B8 || b_type == vllm::kU8B128;
constexpr bool is_8bit_scale = s_type.size_bits() == 8;
// see comments of dequant.h for more details
constexpr bool dequant_skip_flop =
is_a_8bit || b_type == vllm::kFE4M3fn ||
is_a_8bit || (b_type == vllm::kFE4M3fn && !(s_type == vllm::kFE8M0fnu)) ||
b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn ||
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
has_zp && !is_zp_float && !(b_type == vllm::kU8);
@@ -373,7 +376,7 @@ __global__ void Marlin(
const int group_size =
(!has_act_order && group_blocks == -1) ? prob_k : prob_k / num_groups;
const int scales_expert_stride =
prob_n * prob_k / group_size / (b_type == vllm::kFE2M1f ? 16 : 8);
prob_n * prob_k / group_size / (is_8bit_scale ? 16 : 8);
const int zp_expert_stride =
is_zp_float ? prob_n * prob_k / group_size / 8
: prob_n * prob_k / group_size / (pack_factor * 4);
@@ -692,9 +695,8 @@ __global__ void Marlin(
constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
// Scale sizes/strides without act_order
int s_gl_stride = prob_n / (b_type == vllm::kFE2M1f ? 16 : 8);
constexpr int s_sh_stride =
16 * thread_n_blocks / (b_type == vllm::kFE2M1f ? 16 : 8);
int s_gl_stride = prob_n / (is_8bit_scale ? 16 : 8);
constexpr int s_sh_stride = 16 * thread_n_blocks / (is_8bit_scale ? 16 : 8);
constexpr int s_tb_groups =
!has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
? thread_k_blocks / group_blocks
@@ -1131,7 +1133,7 @@ __global__ void Marlin(
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
if constexpr (b_type_id != vllm::kFE2M1f.id()) {
if constexpr (!is_8bit_scale) {
reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
} else {
@@ -1140,7 +1142,7 @@ __global__ void Marlin(
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
}
} else if (group_blocks >= b_sh_wr_iters) {
if constexpr (b_type_id != vllm::kFE2M1f.id()) {
if constexpr (!is_8bit_scale) {
reinterpret_cast<int4*>(&frag_s[1])[0] =
reinterpret_cast<int4*>(&frag_s[0])[0];
} else {
@@ -1341,7 +1343,7 @@ __global__ void Marlin(
}
}
if constexpr (b_type == vllm::kFE2M1f) {
if constexpr (s_type == vllm::kFE4M3fn || s_type == vllm::kFE8M0fnu) {
int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];
int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];

View File

@@ -599,6 +599,9 @@ torch::Tensor moe_wna16_marlin_gemm(
"When b_type = float4_e2m1f, b_scale scalar type must be",
"float8_e4m3fn (for NVFP4) or float8_e8m0fnu (for MXFP4).");
}
} else if (b_type_id == vllm::kFE4M3fn.id() &&
b_scales.scalar_type() == at::ScalarType::Float8_e8m0fnu) {
s_type_id = vllm::kFE8M0fnu.id();
}
vllm::ScalarType a_type = vllm::ScalarType::from_id(a_type_id);

View File

@@ -70,8 +70,4 @@ torch::Tensor router_gemm_bf16_fp32(torch::Tensor const& input,
// Supports num_tokens in [1, 16], num_experts in {256, 384}, hidden_dim = 7168
void dsv3_router_gemm(torch::Tensor& output, const torch::Tensor& mat_a,
const torch::Tensor& mat_b);
// gpt-oss optimized router GEMM kernel for SM90+
void gpt_oss_router_gemm(torch::Tensor& output, torch::Tensor input,
torch::Tensor weight, torch::Tensor bias);
#endif

View File

@@ -132,12 +132,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
// DeepSeek V3 optimized router GEMM for SM90+
m.def("dsv3_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
// conditionally compiled so impl registration is in source file
// gpt-oss optimized router GEMM kernel for SM90+
m.def(
"gpt_oss_router_gemm(Tensor! output, Tensor input, Tensor weights, "
"Tensor bias) -> ()");
m.impl("gpt_oss_router_gemm", torch::kCUDA, &gpt_oss_router_gemm);
#endif
}

View File

@@ -53,12 +53,12 @@ void paged_attention_v2(
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
void merge_attn_states(torch::Tensor& output,
std::optional<torch::Tensor> output_lse,
const torch::Tensor& prefix_output,
const torch::Tensor& prefix_lse,
const torch::Tensor& suffix_output,
const torch::Tensor& suffix_lse);
void merge_attn_states(
torch::Tensor& output, std::optional<torch::Tensor> output_lse,
const torch::Tensor& prefix_output, const torch::Tensor& prefix_lse,
const torch::Tensor& suffix_output, const torch::Tensor& suffix_lse,
const std::optional<int64_t> prefill_tokens_with_context,
const std::optional<torch::Tensor>& output_scale = std::nullopt);
#ifndef USE_ROCM
void convert_vertical_slash_indexes(
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
@@ -143,6 +143,14 @@ void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input,
std::optional<torch::Tensor> residual,
int64_t group_size, bool is_scale_transposed);
#ifndef USE_ROCM
void silu_and_mul_per_block_quant(torch::Tensor& out,
torch::Tensor const& input,
torch::Tensor& scales, int64_t group_size,
std::optional<torch::Tensor> scale_ub,
bool is_scale_transposed);
#endif
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
std::optional<torch::Tensor> key, int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox);
@@ -152,12 +160,6 @@ void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& scale);
#ifndef USE_ROCM
void silu_and_mul_nvfp4_quant(torch::Tensor& out,
torch::Tensor& output_block_scale,
torch::Tensor& input,
torch::Tensor& input_global_scale);
#endif
void persistent_masked_m_silu_mul_quant(
const at::Tensor& input, // (E, T, 2*H)
const at::Tensor& counts, // (E)
@@ -225,44 +227,6 @@ torch::Tensor ggml_moe_a8_vec(torch::Tensor X, torch::Tensor W,
int64_t ggml_moe_get_block_size(int64_t type);
#ifndef USE_ROCM
bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability);
void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A,
torch::Tensor const& B, torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha);
void cutlass_fp4_group_mm(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets);
std::tuple<torch::Tensor, torch::Tensor> scaled_fp4_quant_func(
torch::Tensor const& input, torch::Tensor const& input_scale,
bool is_sf_swizzled_layout);
void scaled_fp4_quant_out(torch::Tensor const& input,
torch::Tensor const& input_scale,
bool is_sf_swizzled_layout, torch::Tensor& output,
torch::Tensor& output_scale);
void scaled_fp4_experts_quant(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts);
void silu_and_mul_scaled_fp4_experts_quant(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts);
#endif
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor const& scale,
std::optional<torch::Tensor> const& azp);

View File

@@ -1,163 +0,0 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/all.h>
#include "cutlass_extensions/common.hpp"
#include "nvfp4_utils.cuh"
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
torch::Tensor const& input,
torch::Tensor const& output_sf,
torch::Tensor const& input_sf,
bool is_sf_swizzled_layout);
#endif
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
void scaled_fp4_experts_quant_sm1xxa(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts);
#endif
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output,
torch::Tensor& output_sf,
torch::Tensor& input,
torch::Tensor& input_sf);
#endif
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts);
#endif
static bool nvfp4_quant_sm_supported() {
const int32_t sm = get_sm_version_num();
#if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100
if (sm >= 100 && sm < 120) return true;
#endif
#if defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120
if (sm >= 120 && sm < 130) return true;
#endif
return false;
}
void scaled_fp4_quant_out(torch::Tensor const& input,
torch::Tensor const& input_sf,
bool is_sf_swizzled_layout, torch::Tensor& output,
torch::Tensor& output_sf) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
TORCH_CHECK(nvfp4_quant_sm_supported(),
"No compiled nvfp4 quantization kernel for SM ",
get_sm_version_num(),
". Recompile with the appropriate CUDA arch.");
return scaled_fp4_quant_sm1xxa(output, input, output_sf, input_sf,
is_sf_swizzled_layout);
#endif
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization kernel");
}
std::tuple<torch::Tensor, torch::Tensor> scaled_fp4_quant_func(
torch::Tensor const& input, torch::Tensor const& input_sf,
bool is_sf_swizzled_layout) {
int64_t n = input.size(-1);
int64_t m = input.numel() / n;
auto device = input.device();
// Two fp4 values packed into a uint8
auto output = torch::empty(
{m, n / 2}, torch::TensorOptions().device(device).dtype(torch::kUInt8));
torch::Tensor output_sf;
if (is_sf_swizzled_layout) {
auto [sf_m, sf_n] = vllm::computeSwizzledSFShape(m, n);
output_sf = torch::empty(
{sf_m, sf_n},
torch::TensorOptions().device(device).dtype(torch::kInt32));
} else {
output_sf = torch::empty(
{m, n / CVT_FP4_SF_VEC_SIZE},
torch::TensorOptions().device(device).dtype(torch::kUInt8));
}
scaled_fp4_quant_out(input, input_sf, is_sf_swizzled_layout, output,
output_sf);
return {output, output_sf};
}
void scaled_fp4_experts_quant(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
TORCH_CHECK(nvfp4_quant_sm_supported(),
"No compiled nvfp4 experts quantization kernel for SM ",
get_sm_version_num(),
". Recompile with the appropriate CUDA arch.");
return scaled_fp4_experts_quant_sm1xxa(
output, output_scale, input, input_global_scale, input_offset_by_experts,
output_scale_offset_by_experts);
#endif
TORCH_CHECK_NOT_IMPLEMENTED(false,
"No compiled nvfp4 experts quantization kernel");
}
void silu_and_mul_nvfp4_quant(torch::Tensor& output, torch::Tensor& output_sf,
torch::Tensor& input, torch::Tensor& input_sf) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
TORCH_CHECK(nvfp4_quant_sm_supported(),
"No compiled silu_and_mul nvfp4 quantization kernel for SM ",
get_sm_version_num(),
". Recompile with the appropriate CUDA arch.");
return silu_and_mul_nvfp4_quant_sm1xxa(output, output_sf, input, input_sf);
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
false, "No compiled silu_and_mul nvfp4 quantization kernel");
}
void silu_and_mul_scaled_fp4_experts_quant(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
TORCH_CHECK(nvfp4_quant_sm_supported(),
"No compiled silu_and_mul nvfp4 experts quantization kernel "
"for SM ",
get_sm_version_num(),
". Recompile with the appropriate CUDA arch.");
return silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
output, output_scale, input, input_global_scale, input_offset_by_experts,
output_scale_offset_by_experts);
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
false, "No compiled silu_and_mul nvfp4 experts quantization kernel");
}

View File

@@ -0,0 +1,169 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "../../dispatch_utils.h"
#include "quant_conversions.cuh"
#include "../w8a8/fp8/common.cuh"
namespace vllm {
// Logic: one thread block per (token, group) pair
template <typename scalar_t, typename scalar_out_t, bool is_scale_transposed,
int32_t group_size>
__global__ void silu_and_mul_per_block_quant_kernel(
scalar_out_t* __restrict__ out, // Output: [num_tokens, hidden_size] in
// FP8/INT8
float* __restrict__ scales, // Output: [num_tokens, hidden_size /
// group_size] or [hidden_size / group_size,
// num_tokens]
scalar_t const* __restrict__ input, // Input: [num_tokens, hidden_size * 2]
float const* scale_ub, // Optional scale upper bound
int32_t const hidden_size // Output hidden size (input is 2x this)
) {
static_assert((group_size & (group_size - 1)) == 0,
"group_size must be a power of 2 for correct reduction");
// Grid: (num_tokens, num_groups)
int const token_idx = blockIdx.x;
int const group_idx = blockIdx.y;
int const tid = threadIdx.x; // tid in [0, group_size)
int const num_tokens = gridDim.x;
// Input layout: [gate || up] concatenated along last dimension
int const input_stride = hidden_size * 2;
int const group_start = group_idx * group_size;
// Pointers to this token's data
scalar_t const* token_input_gate =
input + token_idx * input_stride + group_start;
scalar_t const* token_input_up = token_input_gate + hidden_size;
scalar_out_t* token_output = out + token_idx * hidden_size + group_start;
// Scale pointer for this group
int const num_groups = gridDim.y;
float* group_scale_ptr = is_scale_transposed
? scales + group_idx * num_tokens + token_idx
: scales + token_idx * num_groups + group_idx;
// Shared memory for reduction (compile-time sized)
__shared__ float shared_max[group_size];
// Step 1: Each thread loads one element, computes SiLU, stores in register
float gate = static_cast<float>(token_input_gate[tid]);
float up = static_cast<float>(token_input_up[tid]);
// Compute SiLU(gate) * up
float sigmoid_gate = 1.0f / (1.0f + expf(-gate));
float silu_gate = gate * sigmoid_gate;
float result = silu_gate * up; // Keep in register
// Step 2: Reduce to find group max
shared_max[tid] = fabsf(result);
__syncthreads();
// Power-of-2 reduction (group_size guaranteed to be power of 2)
#pragma unroll
for (int stride = group_size / 2; stride > 0; stride >>= 1) {
if (tid < stride) {
shared_max[tid] = fmaxf(shared_max[tid], shared_max[tid + stride]);
}
__syncthreads();
}
// Step 3: Compute scale (thread 0), broadcast via shared memory
if (tid == 0) {
float group_max = shared_max[0];
float const quant_range = quant_type_max_v<scalar_out_t>;
float group_scale = group_max / quant_range;
// Apply scale upper bound if provided
if (scale_ub != nullptr) {
group_scale = fminf(group_scale, *scale_ub);
}
// Use minimum safe scaling factor
group_scale = fmaxf(group_scale, min_scaling_factor<scalar_out_t>::val());
// Store scale to global memory
*group_scale_ptr = group_scale;
// Reuse shared_max[0] to broadcast scale
shared_max[0] = group_scale;
}
__syncthreads();
float group_scale = shared_max[0];
// Step 4: Quantize and write output
token_output[tid] =
vllm::ScaledQuant<scalar_out_t, false>::quant_fn(result, group_scale);
}
} // namespace vllm
void silu_and_mul_per_block_quant(torch::Tensor& out,
torch::Tensor const& input,
torch::Tensor& scales, int64_t group_size,
std::optional<torch::Tensor> scale_ub,
bool is_scale_transposed) {
static c10::ScalarType kFp8Type = is_fp8_ocp()
? c10::ScalarType::Float8_e4m3fn
: c10::ScalarType::Float8_e4m3fnuz;
TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8);
TORCH_CHECK(out.is_contiguous() && input.is_contiguous());
TORCH_CHECK(
input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16,
"Input must be FP16 or BF16");
TORCH_CHECK(scales.dtype() == torch::kFloat32, "Scales must be FP32");
TORCH_CHECK(group_size == 128 || group_size == 64,
"Unsupported group size: ", group_size);
if (scale_ub.has_value()) {
TORCH_CHECK(out.dtype() == kFp8Type);
}
int32_t hidden_size = out.size(-1);
auto num_tokens = input.size(0);
int32_t num_groups = hidden_size / group_size;
TORCH_CHECK(input.size(-1) == hidden_size * 2,
"input last dim must be 2x output hidden_size");
TORCH_CHECK(hidden_size % group_size == 0,
"hidden_size must be divisible by group_size");
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 grid(num_tokens, num_groups);
dim3 block(group_size);
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "silu_and_mul_per_block_quant", [&] {
using scalar_in_t = scalar_t;
VLLM_DISPATCH_QUANT_TYPES(
out.scalar_type(), "silu_and_mul_per_block_quant", [&] {
using scalar_out_t = scalar_t;
VLLM_DISPATCH_GROUP_SIZE(group_size, gs, [&] {
VLLM_DISPATCH_BOOL(is_scale_transposed, transpose_scale, [&] {
vllm::silu_and_mul_per_block_quant_kernel<
scalar_in_t, scalar_out_t, transpose_scale, gs>
<<<grid, block, 0, stream>>>(
out.data_ptr<scalar_out_t>(),
scales.data_ptr<float>(),
input.data_ptr<scalar_in_t>(),
scale_ub.has_value() ? scale_ub->data_ptr<float>()
: nullptr,
hidden_size);
});
});
});
});
}

View File

@@ -108,6 +108,15 @@ QUANT_CONFIGS = [
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [2],
},
# MXFP8
{
"a_type": ["kBFloat16"],
"b_type": "kFE4M3fn",
"s_type": "kFE8M0fnu",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [2],
},
# AWQ-INT4 with INT8 activation
{
"a_type": ["kS8"],

View File

@@ -591,6 +591,9 @@ torch::Tensor marlin_gemm(
"When b_type = float4_e2m1f, b_scale scalar type must be",
"float8_e4m3fn (for NVFP4) or float8_e8m0fnu (for MXFP4).");
}
} else if (b_type_id == vllm::kFE4M3fn.id() &&
b_scales.scalar_type() == at::ScalarType::Float8_e8m0fnu) {
s_type_id = vllm::kFE8M0fnu.id();
}
vllm::ScalarType a_type = vllm::ScalarType::from_id(a_type_id);

View File

@@ -327,6 +327,9 @@ __global__ void Marlin(
if constexpr (b_type == vllm::kFE2M1f) {
static_assert(s_type == vllm::kFE4M3fn && group_blocks == 1 ||
s_type == vllm::kFE8M0fnu && group_blocks == 2);
} else if constexpr (s_type == vllm::kFE8M0fnu) {
// MXFP8: FP8 weights with e8m0 microscaling block scales
static_assert(b_type == vllm::kFE4M3fn && group_blocks == 2);
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
static_assert(s_type == vllm::kBFloat16);
} else if constexpr (std::is_same<scalar_t, half>::value) {
@@ -334,6 +337,7 @@ __global__ void Marlin(
}
constexpr bool is_a_8bit = a_type.size_bits() == 8;
constexpr bool is_8bit_scale = s_type.size_bits() == 8;
if constexpr (!is_a_8bit) {
static_assert(std::is_same<scalar_t, c_scalar_t>::value);
}
@@ -343,7 +347,7 @@ __global__ void Marlin(
b_type == vllm::kU4B8 || b_type == vllm::kU8B128;
// see comments of dequant.h for more details
constexpr bool dequant_skip_flop =
is_a_8bit || b_type == vllm::kFE4M3fn ||
is_a_8bit || (b_type == vllm::kFE4M3fn && !(s_type == vllm::kFE8M0fnu)) ||
b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn ||
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
has_zp && !is_zp_float && !(b_type == vllm::kU8);
@@ -555,9 +559,8 @@ __global__ void Marlin(
constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
// Scale sizes/strides without act_order
int s_gl_stride = prob_n / (b_type == vllm::kFE2M1f ? 16 : 8);
constexpr int s_sh_stride =
16 * thread_n_blocks / (b_type == vllm::kFE2M1f ? 16 : 8);
int s_gl_stride = prob_n / (is_8bit_scale ? 16 : 8);
constexpr int s_sh_stride = 16 * thread_n_blocks / (is_8bit_scale ? 16 : 8);
constexpr int s_tb_groups =
!has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
? thread_k_blocks / group_blocks
@@ -997,7 +1000,7 @@ __global__ void Marlin(
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
if constexpr (b_type_id != vllm::kFE2M1f.id()) {
if constexpr (!is_8bit_scale) {
reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
} else {
@@ -1006,7 +1009,7 @@ __global__ void Marlin(
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
}
} else if (group_blocks >= b_sh_wr_iters) {
if constexpr (b_type_id != vllm::kFE2M1f.id()) {
if constexpr (!is_8bit_scale) {
reinterpret_cast<int4*>(&frag_s[1])[0] =
reinterpret_cast<int4*>(&frag_s[0])[0];
} else {
@@ -1207,7 +1210,7 @@ __global__ void Marlin(
}
}
if constexpr (b_type == vllm::kFE2M1f) {
if constexpr (s_type == vllm::kFE4M3fn || s_type == vllm::kFE8M0fnu) {
int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];
int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];

View File

@@ -2,7 +2,6 @@
#include "cuda_utils.h"
#include "ops.h"
#include "core/registration.h"
#include <torch/library.h>
#include <torch/version.h>
@@ -73,7 +72,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor prefix_output,"
" Tensor prefix_lse,"
" Tensor suffix_output,"
" Tensor suffix_lse) -> ()");
" Tensor suffix_lse,"
" int!? prefill_tokens_with_context,"
" Tensor? output_scale=None) -> ()");
ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states);
#ifndef USE_ROCM
ops.def(
@@ -109,13 +110,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant);
#ifndef USE_ROCM
ops.def(
"silu_and_mul_nvfp4_quant(Tensor! result, Tensor! result_block_scale, "
"Tensor input, Tensor input_global_scale) -> ()");
ops.impl("silu_and_mul_nvfp4_quant", torch::kCUDA, &silu_and_mul_nvfp4_quant);
#endif
ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()");
ops.impl("mul_and_silu", torch::kCUDA, &mul_and_silu);
@@ -239,6 +233,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Quantization ops
#ifndef USE_ROCM
// Fused SiLU+Mul + per-block quantization
ops.def(
"silu_and_mul_per_block_quant("
"Tensor! out, "
"Tensor input, "
"Tensor! scales, "
"int group_size, "
"Tensor? scale_ub=None, "
"bool is_scale_transposed=False) -> ()");
ops.impl("silu_and_mul_per_block_quant", torch::kCUDA,
&silu_and_mul_per_block_quant);
// DeepSeek V3 fused A GEMM (SM 9.0+, bf16 only, 1-16 tokens).
ops.def(
"dsv3_fused_a_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
@@ -332,47 +337,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor? qzeros_or_none, bool inplace) -> Tensor");
// conditionally compiled so impl registrations are in source file
// CUTLASS w4a8 GEMM
ops.def(
"cutlass_w4a8_mm("
" Tensor A,"
" Tensor B,"
" Tensor group_scales,"
" int group_size,"
" Tensor channel_scales,"
" Tensor token_scales,"
" ScalarType? out_type,"
" str? maybe_schedule"
") -> Tensor");
// pack scales
ops.def("cutlass_pack_scale_fp8(Tensor scales) -> Tensor");
// encode and reorder weight matrix
ops.def("cutlass_encode_and_reorder_int4b(Tensor B) -> Tensor");
// conditionally compiled so impl registration is in source file
// CUTLASS w4a8 grouped GEMM
ops.def(
"cutlass_w4a8_moe_mm("
" Tensor! out_tensors,"
" Tensor a_tensors,"
" Tensor b_tensors,"
" Tensor a_scales,"
" Tensor b_scales,"
" Tensor b_group_scales,"
" int b_group_size,"
" Tensor expert_offsets,"
" Tensor problem_sizes,"
" Tensor a_strides,"
" Tensor b_strides,"
" Tensor c_strides,"
" Tensor group_scale_strides,"
" str? maybe_schedule"
") -> ()");
ops.def(
"cutlass_encode_and_reorder_int4b_grouped(Tensor b_tensors) -> (Tensor, "
"Tensor)");
// conditionally compiled so impl registration is in source file
#endif
// Dequantization for GGML.
@@ -409,20 +373,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("ggml_moe_get_block_size", &ggml_moe_get_block_size);
#ifndef USE_ROCM
// CUTLASS nvfp4 block scaled GEMM
ops.def(
"cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
" Tensor block_scale_a, Tensor block_scale_b,"
" Tensor alpha) -> ()");
ops.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm);
// cutlass nvfp4 block scaled group GEMM
ops.def(
"cutlass_fp4_group_mm(Tensor! out, Tensor a, Tensor b,"
" Tensor a_blockscale, Tensor b_blockscales, Tensor alphas,"
" Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()");
// conditionally compiled so impl registration is in source file
// Expert-specialization mxfp8 blockscaled grouped quantization (SM100+).
ops.def(
"mxfp8_experts_quant("
@@ -455,44 +405,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"-> int");
// conditionally compiled so impl in source file
// Compute NVFP4 block quantized tensor.
ops.def(
"scaled_fp4_quant(Tensor input,"
" Tensor input_scale, bool "
"is_sf_swizzled_layout) -> (Tensor, Tensor)");
ops.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant_func);
// Out variant
// TODO: Add {at::Tag::out_variant} tag and update all call sites
// to use the functional variant once vLLM upgrades PyTorch.
// See pytorch/pytorch#176117.
ops.def(
"scaled_fp4_quant.out(Tensor input,"
" Tensor input_scale, bool "
"is_sf_swizzled_layout, *, Tensor(a!) output, Tensor(b!) output_scale) "
"-> ()");
ops.impl("scaled_fp4_quant.out", torch::kCUDA, &scaled_fp4_quant_out);
// Compute NVFP4 experts quantization.
ops.def(
"scaled_fp4_experts_quant(Tensor! output, Tensor! output_scale,"
"Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
"Tensor output_scale_offset_by_experts) -> ()");
ops.impl("scaled_fp4_experts_quant", torch::kCUDA, &scaled_fp4_experts_quant);
// Fused SiLU+Mul+NVFP4 experts quantization.
ops.def(
"silu_and_mul_scaled_fp4_experts_quant(Tensor! output, Tensor! "
"output_scale,"
"Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
"Tensor output_scale_offset_by_experts) -> ()");
ops.impl("silu_and_mul_scaled_fp4_experts_quant", torch::kCUDA,
&silu_and_mul_scaled_fp4_experts_quant);
// Check if cutlass_scaled_mm_fp4 is supported for CUDA devices
// of the given capability
ops.def("cutlass_scaled_mm_supports_fp4(int cuda_device_capability) -> bool");
ops.impl("cutlass_scaled_mm_supports_fp4", &cutlass_scaled_mm_supports_fp4);
#endif
// Quantized GEMM for GPTQ.
@@ -596,6 +508,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
" int block_size_in_bytes, Tensor block_mapping) -> ()");
cache_ops.impl("swap_blocks", torch::kCUDA, &swap_blocks);
// Batch swap: submit all block copies in a single driver call.
cache_ops.def(
"swap_blocks_batch(Tensor src_ptrs, Tensor dst_ptrs,"
" Tensor sizes) -> ()");
cache_ops.impl("swap_blocks_batch", torch::kCPU, &swap_blocks_batch);
// Reshape the key and value tensors and cache them.
cache_ops.def(
"reshape_and_cache(Tensor key, Tensor value,"

View File

@@ -203,7 +203,8 @@ WORKDIR /vllm-workspace
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=cache,target=/root/.cache/ccache \
--mount=type=bind,from=vllm-build,src=/vllm-workspace/dist,target=dist \
uv pip install dist/*.whl
uv pip install dist/*.whl && \
uv pip install "vllm[audio]"
# Add labels to document build configuration
LABEL org.opencontainers.image.title="vLLM CPU"

View File

@@ -165,9 +165,9 @@ Priority is **1 = highest** (tried first).
| Backend | Version | Dtypes | KV Dtypes | Block Sizes | Head Sizes | Sink | MM Prefix | DCP | Attention Types | Compute Cap. |
| ------- | ------- | ------ | --------- | ----------- | ---------- | ---- | --------- | --- | --------------- | ------------ |
| `CPU_ATTN` | | fp16, bf16, fp32 | `auto` | Any | 32, 64, 80, 96, 112, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | All | N/A |
| `CPU_ATTN` | | fp16, bf16, fp32 | `auto` | Any | 32, 64, 80, 96, 112, 128, 160, 192, 224, 256, 512 | ❌ | ❌ | ❌ | All | N/A |
| `FLASHINFER` | Native† | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ❌ | ❌ | ✅ | Decoder | 7.x-9.x |
| `FLASHINFER` | TRTLLM† | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ✅ | ❌ | ✅ | Decoder | 10.x |
| `FLASHINFER` | TRTLLM† | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ✅ | ❌ | ✅ | Decoder | 10.0 |
| `FLASH_ATTN` | FA2* | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥8.0 |
| `FLASH_ATTN` | FA3* | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ❌ | ✅ | All | 9.x |
| `FLASH_ATTN` | FA4* | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥10.0 |
@@ -177,7 +177,7 @@ Priority is **1 = highest** (tried first).
| `ROCM_AITER_UNIFIED_ATTN` | | fp16, bf16 | `auto` | %16 | Any | ✅ | ✅ | ❌ | All | N/A |
| `ROCM_ATTN` | | fp16, bf16, fp32 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | 32, 64, 80, 96, 128, 160, 192, 224, 256 | ❌ | ✅ | ❌ | Decoder, Encoder, Encoder Only | N/A |
| `TREE_ATTN` | | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | Decoder | Any |
| `TRITON_ATTN` | | fp16, bf16, fp32 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ✅ | ❌ | All | Any |
| `TRITON_ATTN` | | fp16, bf16, fp32 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2`, `int8_per_token_head`, `fp8_per_token_head` | %16 | Any | ✅ | ✅ | ❌ | All | Any |
> **†** FlashInfer uses TRTLLM attention on Blackwell (SM100), which supports sinks. Disable via `--attention-config.use_trtllm_attention=0`.
>

View File

@@ -225,7 +225,7 @@ outputs = model.generate(
### Piecewise compilation and full graph custom passes (attention fusion, sequence parallelism)
Unfortunately, some custom compile passes have to see the whole graph to be effective and hence aren't compatible with piecewise compilation. This includes `AttnFusionPass` and `SequenceParallelismPass`. As a short-term solution, we automatically disable piecewise compilation (by setting `splitting_ops=[]`) when attention fusion is enabled. We use CUDA Graph modes `FULL` or `FULL_DECODE_ONLY` (depending on backend support). However, this leads to another optimization incompatibility and confusing performance tradeoffs.
Unfortunately, some custom compile passes have to see the whole graph to be effective and hence aren't compatible with piecewise compilation. This includes `AttnQuantFusionPass` and `SequenceParallelismPass`. As a short-term solution, we automatically disable piecewise compilation (by setting `splitting_ops=[]`) when attention fusion is enabled. We use CUDA Graph modes `FULL` or `FULL_DECODE_ONLY` (depending on backend support). However, this leads to another optimization incompatibility and confusing performance tradeoffs.
Long term, we've added the ability to partition the graph in Inductor instead of right after Dynamo. It can be enabled with `CompilationConfig.use_inductor_graph_partition=True` but is currently experimental and only available with `torch>=2.9`. This also increases compilation time as it has to compile the whole graph and cannot reuse piecewise compilation artifacts. Once vLLM supports 2.9, we plan to make this the default approach as it will also speed up piecewise cudagraph capture.

View File

@@ -22,6 +22,7 @@ or just on the low or high end.
| ------------------------------------------------------------------------------ | ---------------------------- | ---------------------------------------------- | ------------------------------ | ------------------ | --------- | ------------ |
| [AllReduce + RMSNorm](#allreduce--rmsnorm-fuse_allreduce_rms) | `fuse_allreduce_rms` | All-reduce → RMSNorm (+residual_add) (→ quant) | O2 (Hopper/Blackwell + TP > 1) | 5-20% | No | Low |
| [Attention + Quant](#attention--quantization-fuse_attn_quant) | `fuse_attn_quant` | Attention output → FP8/NVFP4 quant | Off by default | 3-7% | Yes | Always |
| [MLA Attention + Quant](#attention--quantization-fuse_attn_quant) | `fuse_attn_quant` | MLA Attention output → FP8/NVFP4 quant | Off by default | TBD | Yes | Always |
| [RoPE + KV-Cache Update](#rope--kv-cache-update-fuse_rope_kvcache) | `fuse_rope_kvcache` | Rotary embedding → KV cache write | O2 (ROCm/AITER only) | 2-4% | No | Low |
| [QK Norm + RoPE](#qk-norm--rope-enable_qk_norm_rope_fusion) | `enable_qk_norm_rope_fusion` | Q/K RMSNorm → rotary embedding | Off by default | 2-3% | No | Low |
| [Sequence Parallelism](#sequence-parallelism-enable_sp) | `enable_sp` | AllReduce → ReduceScatter + AllGather | Off by default | Prereq for AsyncTP | Yes | High |
@@ -40,12 +41,13 @@ The table below lists the quantization schemes supported by each fusion on each
| ---------------------------- | ---------------------------------------- | ---------------------------------------- | ---------------------------------------- | ------------- | ---------------------------------------- |
| `fuse_allreduce_rms` | FP16/BF16, FP8 static, NVFP4 | FP16/BF16, FP8 static | — | — | — |
| `fuse_attn_quant`\* | FP8 static\*, NVFP4\* | FP8 static\* | FP8 static\* | — | FP8 static\* |
| `fuse_attn_quant` (MLA)\* | FP8 static\*, NVFP4\* | FP8 static\* | FP8 static\* | — | FP8 static(untested)\* |
| `fuse_rope_kvcache` | — | — | — | — | FP16/BF16 |
| `enable_qk_norm_rope_fusion` | FP16/BF16 | FP16/BF16 | FP16/BF16† | FP16/BF16† | — |
| `enable_sp` | FP16/BF16, FP8 static† | FP16/BF16, FP8 static | FP16/BF16† | FP16/BF16† | — |
| `fuse_gemm_comms` | FP16/BF16, FP8 static† | FP16/BF16, FP8 static | FP16/BF16† | FP16/BF16† | — |
| `fuse_norm_quant` | FP8 static, FP8 per-token, FP8 per-group | FP8 static, FP8 per-token, FP8 per-group | FP8 static, FP8 per-token, FP8 per-group | — | FP8 static, FP8 per-token, FP8 per-group |
| `fuse_act_quant` | FP8 static, NVFP4 | FP8 static | FP8 static | — | FP8 per-group |
| `fuse_act_quant` | FP8 static, NVFP4 | FP8 static, FP8 per-group (128/64) | FP8 static, FP8 per-group (128/64) | — | FP8 per-group |
| `fuse_act_padding` | — | — | — | — | FP16/BF16 |
\* `fuse_attn_quant` support depends on the attention backend in use; not all backends support
@@ -129,7 +131,8 @@ on SM90/SM100) and configurable via `PassConfig.fi_allreduce_fusion_max_size_mb`
explicitly. It requires the full model graph to be visible (Inductor partition or `splitting_ops=[]`).
**What it fuses.** Fuses the attention output quantization directly after the attention computation,
eliminating a full-precision memory round-trip of the attention output. Patterns covered:
eliminating a full-precision memory round-trip of the attention output. This fusion supports both
standard `Attention` and `MLAAttention` (used by DeepSeek-V2/V3/R1 models). Patterns covered:
`Attention → FP8 static quant`:
@@ -142,11 +145,24 @@ eliminating a full-precision memory round-trip of the attention output. Patterns
- `FLASHINFER`: CUDA sm100+ with FlashInfer installed
`MLAAttention → FP8 static quant` / `MLAAttention → NVFP4 dynamic quant`:
The MLA fusion operates at the graph level on the `unified_mla_attention_with_output` op and works
with all MLA decode and prefill backend combinations. Unlike standard `Attention` backends (where
the kernel writes FP8 output directly), no MLA prefill or decode backend currently supports direct
FP8/FP4 output. The fusion writes to an intermediate buffer and quantizes in a separate step, so
there is no memory round-trip elimination yet.
!!! info
The MLA attention fusion is not expected to yield a measurable speedup yet.
This will improve once MLA prefill/decode kernels support direct FP8/FP4 output.
Other attention backends do not support fused output quantization yet.
**Code locations.**
- Pass: [`vllm/compilation/passes/fusion/attn_quant_fusion.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/attn_quant_fusion.py)
- Pass (Attention): [`vllm/compilation/passes/fusion/attn_quant_fusion.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/attn_quant_fusion.py)
- Pass (MLAAttention): [`vllm/compilation/passes/fusion/mla_attn_quant_fusion.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/mla_attn_quant_fusion.py)
- Attention backends: [`vllm/v1/attention/backends/`](https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/)
### RoPE + KV-Cache Update (`fuse_rope_kvcache`)
@@ -305,6 +321,7 @@ Note that AITER fusions are in a separate pass in `vllm.compilation.passes.fusio
Supported quantization scheme/hardware combinations:
- FP8 static per-tensor: CUDA & HIP kernel
- FP8 dynamic per-group (128/64): CUDA kernel (sm89+, not active when DeepGemm is used on sm100+)
- NVFP4 dynamic: CUDA sm100+ only with FlashInfer
- FP8 per-token-group (128): ROCm AITER only
@@ -313,6 +330,7 @@ Supported quantization scheme/hardware combinations:
- Pass: [`vllm/compilation/passes/fusion/act_quant_fusion.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/act_quant_fusion.py)
- ROCm AITER pass: [`vllm/compilation/passes/fusion/rocm_aiter_fusion.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/rocm_aiter_fusion.py)
- CUDA/HIP kernels: [`csrc/quantization/`](https://github.com/vllm-project/vllm/blob/main/csrc/quantization/)
- Fused SiLU+Mul+BlockQuant kernel: [`csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu`](https://github.com/vllm-project/vllm/blob/main/csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu)
### RMSNorm + Padding (`fuse_act_padding`)

View File

@@ -244,12 +244,12 @@ response = client.chat.completions.create(
Some models, such as [Qwen3](https://qwen.readthedocs.io/en/latest/getting_started/quickstart.html#thinking-budget), [DeepSeek](https://www.alibabacloud.com/help/en/model-studio/deep-thinking), and [Nemotron3](https://huggingface.co/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16), support a thinking budget that limits the maximum number of tokens used for reasoning.
Token counting starts from `think_start_str`. Once the reasoning token count reaches the configured `thinking_token_budget`, vLLM forces the model to produce `think_end_str`, effectively terminating the reasoning block.
Token counting starts from `reasoning_start_str`. Once the reasoning token count reaches the configured `thinking_token_budget`, vLLM forces the model to produce `reasoning_end_str`, effectively terminating the reasoning block.
To use this feature:
- `--reasoning-parser` enables reasoning extraction.
- `--reasoning-config` defines the reasoning boundary tokens (e.g., `think_start_str`, `think_end_str`).
- `--reasoning-config` defines the reasoning boundary tokens (e.g., `reasoning_start_str`, `reasoning_end_str`).
- `thinking_token_budget` (a sampling parameter) sets the per-request reasoning token limit.
If `thinking_token_budget` is not specified, no explicit reasoning limit is applied beyond normal generation constraints such as `max_tokens`.
@@ -257,20 +257,20 @@ If `thinking_token_budget` is not specified, no explicit reasoning limit is appl
`--reasoning-config` accepts a JSON object corresponding to
[ReasoningConfig][vllm.config.ReasoningConfig] with the following fields:
| Field | Type | Description |
|-------------------|----------------|--------------------------------------------------|
| `think_start_str` | `str \| null` | String that marks the start of reasoning content |
| `think_end_str` | `str \| null` | String that marks the end of reasoning content |
| Field | Type | Description |
|-----------------------|----------------|--------------------------------------------------|
| `reasoning_start_str` | `str \| null` | String that marks the start of reasoning content |
| `reasoning_end_str` | `str \| null` | String that marks the end of reasoning content |
!!! note
`think_end_str` can include a transition phrase before the think end token. For example, setting `think_end_str` to `"I have to give the solution based on the thinking directly now.</think>"` instructs the model to emit that phrase when the budget is exhausted, making the reasoning termination more natural.
`reasoning_end_str` can include a transition phrase before the reasoning end token. For example, setting `reasoning_end_str` to `"I have to give the solution based on the reasoning directly now.</think>"` instructs the model to emit that phrase when the budget is exhausted, making the reasoning termination more natural.
### Online Serving
```bash
vllm serve Qwen/Qwen3-0.6B \
--reasoning-parser qwen3 \
--reasoning-config '{"think_start_str": "<think>", "think_end_str": "I have to give the solution based on the thinking directly now.</think>"}'
--reasoning-config '{"reasoning_start_str": "<think>", "reasoning_end_str": "I have to give the solution based on the reasoning directly now.</think>"}'
```
Then make a request with `thinking_token_budget` to limit the reasoning tokens:
@@ -298,8 +298,8 @@ from vllm.config import ReasoningConfig
llm = LLM(
model="Qwen/Qwen3-0.6B",
reasoning_config=ReasoningConfig(
think_start_str="<think>",
think_end_str="I have to give the solution based on the thinking directly now.</think>",
reasoning_start_str="<think>",
reasoning_end_str="I have to give the solution based on the thinking directly now.</think>",
),
)

View File

@@ -481,6 +481,7 @@ th {
| `Step3p5ForCausalLM` | Step-3.5-flash | `stepfun-ai/Step-3.5-Flash`, etc. | | ✅︎ |
| `TeleChatForCausalLM` | TeleChat | `chuhac/TeleChat2-35B`, etc. | ✅︎ | ✅︎ |
| `TeleChat2ForCausalLM` | TeleChat2 | `Tele-AI/TeleChat2-3B`, `Tele-AI/TeleChat2-7B`, `Tele-AI/TeleChat2-35B`, etc. | ✅︎ | ✅︎ |
| `TeleChat3ForCausalLM` | TeleChat3 | `Tele-AI/TeleChat3-36B-Thinking`, `Tele-AI/TeleChat3-Coder-36B-Thinking`, etc. | ✅︎ | ✅︎ |
| `TeleFLMForCausalLM` | TeleFLM | `CofeAI/FLM-2-52B-Instruct-2407`, `CofeAI/Tele-FLM`, etc. | ✅︎ | ✅︎ |
| `XverseForCausalLM` | XVERSE | `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc. | ✅︎ | ✅︎ |
| `MiniMaxM1ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-M1-40k`, `MiniMaxAI/MiniMax-M1-80k`, etc. | | |
@@ -541,6 +542,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `BeeForConditionalGeneration` | Bee-8B | T + I<sup>E+</sup> | `Open-Bee/Bee-8B-RL`, `Open-Bee/Bee-8B-SFT` | | ✅︎ |
| `Blip2ForConditionalGeneration` | BLIP-2 | T + I<sup>E</sup> | `Salesforce/blip2-opt-2.7b`, `Salesforce/blip2-opt-6.7b`, etc. | ✅︎ | ✅︎ |
| `ChameleonForConditionalGeneration` | Chameleon | T + I | `facebook/chameleon-7b`, etc. | | ✅︎ |
| `CheersForConditionalGeneration` | Cheers | T + I | `ai9stars/Cheers` | | ✅︎ |
| `Cohere2VisionForConditionalGeneration` | Command A Vision | T + I<sup>+</sup> | `CohereLabs/command-a-vision-07-2025`, etc. | | ✅︎ |
| `DeepseekVLV2ForCausalLM` | DeepSeek-VL2 | T + I<sup>+</sup> | `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2`, etc. | | ✅︎ |
| `DeepseekOCRForCausalLM` | DeepSeek-OCR | T + I<sup>+</sup> | `deepseek-ai/DeepSeek-OCR`, etc. | ✅︎ | ✅︎ |
@@ -576,7 +578,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `Llama4ForConditionalGeneration` | Llama 4 | T + I<sup>+</sup> | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | ✅︎ | ✅︎ |
| `Llama_Nemotron_Nano_VL` | Llama Nemotron Nano VL | T + I<sup>E+</sup> | `nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1` | ✅︎ | ✅︎ |
| `LlavaForConditionalGeneration` | LLaVA-1.5, Pixtral (HF Transformers) | T + I<sup>E+</sup> | `llava-hf/llava-1.5-7b-hf`, `TIGER-Lab/Mantis-8B-siglip-llama3` (see note), `mistral-community/pixtral-12b`, etc. | ✅︎ | ✅︎ |
| `LlavaNextForConditionalGeneration` | LLaVA-NeXT | T + I<sup>E+</sup> | `llava-hf/llava-v1.6-mistral-7b-hf`, `llava-hf/llava-v1.6-vicuna-7b-hf`, etc. | | ✅︎ |
| `LlavaNextForConditionalGeneration` | LLaVA-NeXT, Granite Vision | T + I<sup>E+</sup> | `llava-hf/llava-v1.6-mistral-7b-hf`, `llava-hf/llava-v1.6-vicuna-7b-hf`, `ibm-granite/granite-vision-3.3-2b`, etc. | | ✅︎ |
| `LlavaNextVideoForConditionalGeneration` | LLaVA-NeXT-Video | T + V | `llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. | | ✅︎ |
| `LlavaOnevisionForConditionalGeneration` | LLaVA-Onevision | T + I<sup>+</sup> + V<sup>+</sup> | `llava-hf/llava-onevision-qwen2-7b-ov-hf`, `llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc. | | ✅︎ |
| `MiDashengLMModel` | MiDashengLM | T + A<sup>+</sup> | `mispeech/midashenglm-7b` | | ✅︎ |

View File

@@ -73,8 +73,11 @@ In addition, we have the following custom APIs:
- [Cohere Embed API](../models/pooling_models/embed.md#cohere-embed-api) (`/v2/embed`)
- Compatible with [Cohere's Embed API](https://docs.cohere.com/reference/embed)
- Works with any [embedding model](../models/pooling_models/embed.md#supported-models), including multimodal models.
- [Score API](../models/pooling_models/scoring.md#score-api) (`/score`)
- Applicable to [score models](../models/pooling_models/scoring.md).
- [Score API](../models/pooling_models/scoring.md#score-api) (`/score`, `/v1/score`)
- Applicable to [score models](../models/pooling_models/scoring.md) (cross-encoder, bi-encoder, late-interaction).
- [Generative Scoring API](#generative-scoring-api) (`/generative_scoring`)
- Applicable to [CausalLM models](../models/generative_models.md) (task `"generate"`).
- Computes next-token probabilities for specified `label_token_ids`.
- [Rerank API](../models/pooling_models/scoring.md#rerank-api) (`/rerank`, `/v1/rerank`, `/v2/rerank`)
- Implements [Jina AI's v1 rerank API](https://jina.ai/reranker/)
- Also compatible with [Cohere's v1 & v2 rerank APIs](https://docs.cohere.com/v2/reference/rerank)
@@ -481,6 +484,71 @@ This approach is more robust than index-based access (`messages[0]`, `messages[1
Example template file: [examples/pooling/score/template/nemotron-rerank.jinja](../../examples/pooling/score/template/nemotron-rerank.jinja)
### Generative Scoring API
The `/generative_scoring` endpoint uses a CausalLM model (e.g., Llama, Qwen, Mistral) to compute the probability of specified token IDs appearing as the next token. Each item (document) is concatenated with the query to form a prompt, and the model predicts how likely each label token is as the next token after that prompt. This lets you score items against a query — for example, asking "Is this the capital of France?" and scoring each city by how likely the model is to answer "Yes".
This endpoint is automatically available when the server is started with a generative model (task `"generate"`). It is separate from the pooling-based [Score API](#score-api), which uses cross-encoder, bi-encoder, or late-interaction models.
**Requirements:**
- The `label_token_ids` parameter is **required** and must contain **at least 1 token ID**.
- When 2 label tokens are provided, the score equals `P(label_token_ids[0]) / (P(label_token_ids[0]) + P(label_token_ids[1]))` (softmax over the two labels).
- When more labels are provided, the score is the softmax-normalized probability of the first label token across all label tokens.
#### Example
```bash
curl -X POST http://localhost:8000/generative_scoring \
-H "Content-Type: application/json" \
-d '{
"model": "Qwen/Qwen3-0.6B",
"query": "Is this city the capital of France?",
"items": ["Paris", "London", "Berlin"],
"label_token_ids": [9454, 2753]
}'
```
Here, each item is appended to the query to form prompts like `"Is this city the capital of France? Paris"`, `"... London"`, etc. The model then predicts the next token, and the score reflects the probability of "Yes" (token 9454) vs "No" (token 2753).
??? console "Response"
```json
{
"id": "generative-scoring-abc123",
"object": "list",
"created": 1234567890,
"model": "Qwen/Qwen3-0.6B",
"data": [
{"index": 0, "object": "score", "score": 0.95},
{"index": 1, "object": "score", "score": 0.12},
{"index": 2, "object": "score", "score": 0.08}
],
"usage": {"prompt_tokens": 45, "total_tokens": 48, "completion_tokens": 3}
}
```
#### How it works
1. **Prompt Construction**: For each item, builds `prompt = query + item` (or `item + query` if `item_first=true`)
2. **Forward Pass**: Runs the model on each prompt to get next-token logits
3. **Probability Extraction**: Extracts logprobs for the specified `label_token_ids`
4. **Softmax Normalization**: Applies softmax over only the label tokens (when `apply_softmax=true`)
5. **Score**: Returns the normalized probability of the first label token
#### Finding Token IDs
To find the token IDs for your labels, use the tokenizer:
```python
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
yes_id = tokenizer.encode("Yes", add_special_tokens=False)[0]
no_id = tokenizer.encode("No", add_special_tokens=False)[0]
print(f"Yes: {yes_id}, No: {no_id}")
```
## Ray Serve LLM
Ray Serve LLM enables scalable, production-grade serving of the vLLM engine. It integrates tightly with vLLM and extends it with features such as auto-scaling, load balancing, and back-pressure.

View File

@@ -54,5 +54,5 @@ with tempfile.TemporaryDirectory() as tmpdirname:
print("Extracted token ids:", token_ids) # Matches prompt token ids
print(
"Extracted hidden states shape:", hidden_states.shape
) # [num_hidden_layers, prompt len, hidden size]
) # [prompt len, num_hidden_layers, hidden size]
print("Extracted hidden states:", hidden_states)

View File

@@ -179,6 +179,33 @@ def run_chameleon(questions: list[str], modality: str) -> ModelRequestData:
)
# Cheers
def run_cheers(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
model_name = "ai9stars/Cheers"
engine_args = EngineArgs(
model=model_name,
trust_remote_code=True,
max_model_len=4096,
limit_mm_per_prompt={modality: 1},
)
prompts = [
(
f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
f"<|im_start|>user\n<|image_pad|>{question}<|im_end|>\n"
f"<|im_start|>assistant\n"
)
for question in questions
]
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
def run_command_a_vision(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
@@ -2140,6 +2167,7 @@ model_example_map = {
"aria": run_aria,
"aya_vision": run_aya_vision,
"bagel": run_bagel,
"cheers": run_cheers,
"bee": run_bee,
"blip-2": run_blip2,
"chameleon": run_chameleon,

View File

@@ -16,5 +16,5 @@ flashinfer-cubin==0.6.7
nvidia-cudnn-frontend>=1.13.0,<1.19.0
# QuACK and Cutlass DSL for FA4 (cute-DSL implementation)
nvidia-cutlass-dsl>=4.4.0.dev1
quack-kernels>=0.2.7
nvidia-cutlass-dsl>=4.4.2
quack-kernels>=0.3.3

View File

@@ -1,5 +1,5 @@
# This file was autogenerated by uv via the following command:
# uv pip compile requirements/test.in -o requirements/test.txt --index-strategy unsafe-best-match --torch-backend cu129 --python-platform x86_64-manylinux_2_28 --python-version 3.12
# uv pip compile requirements/test.in -c requirements/common.txt -o requirements/test.txt --index-strategy unsafe-best-match --torch-backend cu129 --python-platform x86_64-manylinux_2_28 --python-version 3.12
absl-py==2.1.0
# via
# rouge-score
@@ -14,6 +14,7 @@ aiohappyeyeballs==2.6.1
# via aiohttp
aiohttp==3.13.3
# via
# -c requirements/common.txt
# aiohttp-cors
# datasets
# fsspec
@@ -225,7 +226,9 @@ et-xmlfile==2.0.0
evaluate==0.4.3
# via lm-eval
fastapi==0.128.0
# via gpt-oss
# via
# -c requirements/common.txt
# gpt-oss
fastparquet==2024.11.0
# via genai-perf
fastrlock==0.8.2
@@ -234,6 +237,7 @@ fastsafetensors==0.2.2
# via -r requirements/test.in
filelock==3.16.1
# via
# -c requirements/common.txt
# blobfile
# datasets
# diffusers
@@ -505,7 +509,9 @@ mbstrdecoder==1.1.3
mdurl==0.1.2
# via markdown-it-py
mistral-common==1.10.0
# via -r requirements/test.in
# via
# -c requirements/common.txt
# -r requirements/test.in
more-itertools==10.5.0
# via lm-eval
mpmath==1.3.0
@@ -655,13 +661,16 @@ omegaconf==2.3.0
open-clip-torch==2.32.0
# via -r requirements/test.in
openai-harmony==0.0.4
# via gpt-oss
# via
# -c requirements/common.txt
# gpt-oss
opencensus==0.11.4
# via ray
opencensus-context==0.1.3
# via opencensus
opencv-python-headless==4.13.0.90
# via
# -c requirements/common.txt
# -r requirements/test.in
# albucore
# albumentations
@@ -670,15 +679,17 @@ openpyxl==3.1.5
# via -r requirements/test.in
opentelemetry-api==1.35.0
# via
# -c requirements/common.txt
# opentelemetry-exporter-prometheus
# opentelemetry-sdk
# opentelemetry-semantic-conventions
opentelemetry-exporter-prometheus==0.56b0
# via ray
opentelemetry-proto==1.36.0
opentelemetry-proto==1.35.0
# via ray
opentelemetry-sdk==1.35.0
# via
# -c requirements/common.txt
# opentelemetry-exporter-prometheus
# ray
opentelemetry-semantic-conventions==0.56b0
@@ -785,6 +796,7 @@ pqdm==0.2.0
# via -r requirements/test.in
prometheus-client==0.22.0
# via
# -c requirements/common.txt
# opentelemetry-exporter-prometheus
# ray
propcache==0.2.0
@@ -793,8 +805,9 @@ propcache==0.2.0
# yarl
proto-plus==1.26.1
# via google-api-core
protobuf==6.33.2
protobuf==6.33.6
# via
# -c requirements/common.txt
# google-api-core
# googleapis-common-protos
# grpcio-reflection
@@ -836,6 +849,7 @@ pycryptodomex==3.22.0
# via blobfile
pydantic==2.12.0
# via
# -c requirements/common.txt
# -r requirements/test.in
# albumentations
# datamodel-code-generator
@@ -973,6 +987,7 @@ regex==2024.9.11
# transformers
requests==2.32.3
# via
# -c requirements/common.txt
# azure-core
# buildkite-test-collector
# datasets
@@ -1085,6 +1100,7 @@ sentry-sdk==2.52.0
# via wandb
setuptools==77.0.3
# via
# -c requirements/common.txt
# lightning-utilities
# pytablewriter
# tensorboard
@@ -1099,6 +1115,7 @@ shellingham==1.5.4
# typer
six==1.16.0
# via
# -c requirements/common.txt
# junit-xml
# lightly
# opencensus
@@ -1183,6 +1200,7 @@ tifffile==2025.3.30
# terratorch
tiktoken==0.12.0
# via
# -c requirements/common.txt
# gpt-oss
# lm-eval
# mistral-common
@@ -1195,6 +1213,7 @@ timm==1.0.17
# torchgeo
tokenizers==0.22.0
# via
# -c requirements/common.txt
# -r requirements/test.in
# transformers
tomli==2.2.1
@@ -1271,6 +1290,7 @@ tqdm==4.67.3
# transformers
transformers==4.57.5
# via
# -c requirements/common.txt
# -r requirements/test.in
# genai-perf
# peft
@@ -1297,6 +1317,7 @@ typeshed-client==2.8.2
# via jsonargparse
typing-extensions==4.15.0
# via
# -c requirements/common.txt
# aiosignal
# albumentations
# alembic

View File

@@ -1063,7 +1063,10 @@ setup(
# Optional deps for AMD FP4 quantization support
"petit-kernel": ["petit-kernel"],
# Optional deps for Helion kernel development
"helion": ["helion==0.3.2"],
# NOTE: When updating helion version, also update CI files:
# - .buildkite/test_areas/kernels.yaml
# - .buildkite/test-amd.yaml
"helion": ["helion==0.3.3"],
# Optional deps for gRPC server (vllm serve --grpc)
"grpc": ["smg-grpc-servicer[vllm] >= 0.5.0"],
# Optional deps for OpenTelemetry tracing

View File

@@ -8,7 +8,7 @@ from copy import deepcopy
import depyf
from torch import fx
from torch._ops import OpOverload
from torch._ops import OpOverload, OpOverloadPacket
from torch.fx._utils import lazy_format_graph_code
from vllm.compilation.passes.fx_utils import find_op_nodes
@@ -90,7 +90,9 @@ class TestBackend:
# assign by reference, will reflect the final state of the graph
self.final_graph = graph
def check_before_ops(self, ops: Sequence[OpOverload], fully_replaced=True):
def check_before_ops(
self, ops: Sequence[OpOverload | OpOverloadPacket], fully_replaced=True
):
for op in ops:
num_pre = len(list(find_op_nodes(op, self.graph_pre_pass)))
num_post = len(list(find_op_nodes(op, self.graph_post_pass)))
@@ -99,13 +101,19 @@ class TestBackend:
if fully_replaced:
assert num_post == 0, f"Unexpected op {op.name()} in post-pass graph"
def check_after_ops(self, ops: Sequence[OpOverload]):
def check_after_ops(self, ops: Sequence[OpOverload | OpOverloadPacket]):
for op in ops:
num_pre = len(list(find_op_nodes(op, self.graph_pre_pass)))
num_post = len(list(find_op_nodes(op, self.graph_post_pass)))
assert num_pre == 0, f"Unexpected op {op.name()} in pre-pass graph"
assert num_post > 0, f"Op {op.name()} not found in post-pass graph"
def op_count(self, op: OpOverload, before=False) -> int:
def op_count(self, op: OpOverload | OpOverloadPacket, before=False) -> int:
graph = self.graph_pre_pass if before else self.graph_post_pass
return len(list(find_op_nodes(op, graph)))
def print_graphs(self):
print("=== Graph before custom passes ===")
print(self.graph_pre_pass.python_code(root_module="self", verbose=True).src)
print("=== Graph after custom passes ===")
print(self.graph_post_pass.python_code(root_module="self", verbose=True).src)

View File

@@ -170,14 +170,3 @@ class TestFullCUDAGraph:
piecewise_res.outputs[0].text.lower()
== full_res.outputs[0].text.lower()
)
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
def test_full_cudagraph_with_invalid_backend():
# Flex_Attention is not supported with full cuda graph
with pytest.raises(RuntimeError):
LLM(
model="Qwen/Qwen2-1.5B-Instruct",
compilation_config=CompilationConfig(cudagraph_mode="FULL"),
attention_config={"backend": "FLEX_ATTENTION"},
)

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
from collections import defaultdict
import pytest
import regex as re
@@ -52,6 +53,16 @@ def run_model(compile_config: int | CompilationConfig, model: str, **model_kwarg
llm.llm_engine.vllm_config.compilation_config.compile_ranges_endpoints
)
# Fetch match table from each worker via RPC and sum across workers.
worker_tables = llm.llm_engine.engine_core.collective_rpc(
"get_compilation_match_table"
)
combined: defaultdict[str, int] = defaultdict(int)
for table in worker_tables:
for k, v in table.items():
combined[k] += v
return dict(combined)
@pytest.fixture
def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
@@ -73,10 +84,14 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
rocm_aiter_ops.refresh_env_variables()
# Filter here to reduce code duplication
backend_name = attn_backend.backend.name.lower()
requires_mla = "deepseek" in model_name.lower()
is_mla = "mla" in attn_backend.backend.name.lower()
is_mla = "mla" in backend_name
# DeepSeek V3.2 uses sparse MLA
requires_sparse = "v3.2" in model_name.lower()
is_sparse = "sparse" in backend_name
if requires_mla != is_mla:
if requires_mla != is_mla or requires_sparse != is_sparse:
pytest.skip(
f"Incompatible model '{model_name}' and "
f"attention backend '{attn_backend.backend.name}'"
@@ -113,7 +128,7 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
)
with caplog_mp_spawn(logging.DEBUG) as log_holder:
run_model(full_compilation_config, model_name, **model_kwargs)
match_table = run_model(full_compilation_config, model_name, **model_kwargs)
num_compile_ranges = len(full_compilation_config.get_compile_ranges())
assert num_compile_ranges in [1, 2, 3]
@@ -155,11 +170,14 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
else:
num_ranges_activated = num_compile_ranges
# TODO: Remove log counting in unit tests
# once all matchers implement VllmFusionPatternMatcherPass
n_expected = tp_size * num_ranges_activated
assert len(log_matches) == n_expected, (
f"Could not find {n_expected} {match_name} "
f"(found {len(log_matches)}) in:\n {log_holder.text}"
)
if match_name != "attn_quant_fusion":
assert len(log_matches) == n_expected, (
f"Could not find {n_expected} {match_name} "
f"(found {len(log_matches)}) in:\n {log_holder.text}"
)
expected_matches = getattr(matches, match_name)
@@ -215,6 +233,15 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
f"{tp_size * (num_ranges_activated - 1)} large-range "
f"entries (SP took precedence), found: {log_matches}"
)
elif match_name == "attn_quant_fusion":
actual_match = match_table.get(
"attn_quant_fusion", 0
) + match_table.get("mla_attn_quant_fusion", 0)
assert actual_match == expected_matches * n_expected, (
f"Could not find {expected_matches * n_expected} "
f"{match_name} (found {actual_match})."
)
else:
expected_matches_list = [expected_matches] * n_expected
assert sorted(log_matches) == expected_matches_list, (

View File

@@ -58,6 +58,15 @@ TRITON_MLA_ATTN = pytest.param(
id="TRITON_MLA",
)
FLASHMLA_SPARSE_ATTN = pytest.param(
AttentionBackendCase(backend=AttentionBackendEnum.FLASHMLA_SPARSE),
id="FLASHMLA_SPARSE",
marks=pytest.mark.skipif(
not is_blackwell(),
reason="FlashMLA Sparse requires Blackwell",
),
)
# Models
llama3_8b = ModelFusionInfo(
model_name="meta-llama/Llama-3.1-8B-Instruct",
@@ -141,6 +150,18 @@ qwen3_a3b_fp8 = ModelFusionInfo(
),
)
deepseek_coder_v2_lite_fp8 = ModelFusionInfo(
model_name="RedHatAI/DeepSeek-Coder-V2-Lite-Instruct-FP8",
matches=lambda n_layers: Matches(
# first_k_dense_replace=1; MoE hides most rms+quant sites
rms_quant_fusion=1,
act_quant_fusion=min(1, n_layers), # dense layers only
# MLA attn + static FP8 quant
attn_quant_fusion=n_layers,
ar_rms_fusion=n_layers * 2 + 1,
),
)
deepseek_v3_fp8 = ModelFusionInfo(
model_name="deepseek-ai/DeepSeek-V3",
matches=lambda n_layers: Matches(
@@ -150,10 +171,9 @@ deepseek_v3_fp8 = ModelFusionInfo(
# - post_attn_layernorm + MLP
# 2 per MoE layer (remaining) due to MoE wrapping
rms_quant_fusion=n_layers * 2 + min(3, n_layers), # add for 3 dense layers
# TODO silu+block quant
# act_quant_fusion=min(3, n_layers), # dense layers only
act_quant_fusion=0,
# MLA attn + quant not supported yet:
# silu+block quant
act_quant_fusion=min(3, n_layers), # dense layers only
# MLA attn + per-group FP8 quant not supported yet:
# https://github.com/vllm-project/vllm/issues/35792
attn_quant_fusion=0,
ar_rms_fusion=n_layers * 2 + 1,
@@ -163,6 +183,16 @@ deepseek_v3_fp8 = ModelFusionInfo(
),
)
deepseek_v32_fp4 = ModelFusionInfo(
model_name="nvidia/DeepSeek-V3.2-NVFP4",
matches=lambda n_layers: Matches(
rms_quant_fusion=0,
act_quant_fusion=0,
attn_quant_fusion=n_layers,
ar_rms_fusion=n_layers * 2 + 1,
),
)
gpt_oss_20b = ModelFusionInfo(
model_name="openai/gpt-oss-20b",
matches=lambda n_layers: Matches(

View File

@@ -18,11 +18,14 @@ from .common import (
from .models import (
FLASHINFER_ATTN,
FLASHINFER_MLA_ATTN,
FLASHMLA_SPARSE_ATTN,
ROCM_AITER_UNIFIED_ATTN,
ROCM_ATTN,
TRITON_ATTN,
TRITON_MLA_ATTN,
deepseek_coder_v2_lite_fp8,
deepseek_v3_fp8,
deepseek_v32_fp4,
llama3_8b_fp4,
llama3_8b_fp8,
llama4_scout_fp4,
@@ -37,6 +40,7 @@ from .models import (
(*llama3_8b_fp8, False),
(*qwen3_a3b_fp8, False),
(*qwen3_a3b_fp8, True),
(*deepseek_coder_v2_lite_fp8, False),
(*deepseek_v3_fp8, False),
(*deepseek_v3_fp8, True),
pytest.param(
@@ -99,6 +103,8 @@ def test_tp1_fp8_fusions(
model_kwargs["hf_overrides"] = hf_overrides(n_layers)
model_kwargs["load_format"] = "dummy"
model_kwargs["max_model_len"] = 1024
model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False}
compilation_config = dict(
use_inductor_graph_partition=inductor_graph_partition,
custom_ops=custom_ops.split(","),
@@ -142,9 +148,12 @@ def test_tp1_fp8_fusions(
@pytest.mark.parametrize(
"model_name, matches_fn, model_kwargs, hf_overrides",
[llama3_8b_fp4, llama4_scout_fp4],
[llama3_8b_fp4, llama4_scout_fp4, deepseek_v32_fp4],
)
@pytest.mark.parametrize(
"attn_backend",
[FLASHINFER_ATTN, FLASHMLA_SPARSE_ATTN],
)
@pytest.mark.parametrize("attn_backend", [FLASHINFER_ATTN])
@pytest.mark.parametrize("n_layers", [6])
@pytest.mark.parametrize("custom_ops", custom_ops_combos("rms_norm"))
@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION)
@@ -166,6 +175,7 @@ def test_tp1_fp4_fusions(
model_kwargs["hf_overrides"] = hf_overrides(n_layers)
model_kwargs["load_format"] = "dummy"
model_kwargs["max_model_len"] = 1024
model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False}
compilation_config = dict(
use_inductor_graph_partition=inductor_graph_partition,

View File

@@ -18,8 +18,11 @@ from .common import (
from .models import (
FLASHINFER_ATTN,
FLASHINFER_MLA_ATTN,
FLASHMLA_SPARSE_ATTN,
TRITON_ATTN,
deepseek_coder_v2_lite_fp8,
deepseek_v3_fp8,
deepseek_v32_fp4,
gpt_oss_20b,
llama3_8b,
llama3_8b_fp4,
@@ -37,7 +40,13 @@ pytestmark = pytest.mark.skipif(not current_platform.is_cuda(), reason="Only tes
@pytest.mark.parametrize(
"model_name, matches_fn, model_kwargs, hf_overrides",
# qwen3 & dsv3 should still fuse AR+rms even though group quant is not yet supported
[llama3_8b_fp8, llama4_scout_fp8, qwen3_a3b_fp8, deepseek_v3_fp8],
[
llama3_8b_fp8,
llama4_scout_fp8,
qwen3_a3b_fp8,
deepseek_coder_v2_lite_fp8,
deepseek_v3_fp8,
],
)
@pytest.mark.parametrize(
"attn_backend", [TRITON_ATTN, FLASHINFER_ATTN, FLASHINFER_MLA_ATTN]
@@ -68,6 +77,7 @@ def test_tp2_ar_rms_fp8_fusions(
model_kwargs["hf_overrides"] = hf_overrides(n_layers)
model_kwargs["load_format"] = "dummy"
model_kwargs["max_model_len"] = 1024
model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False}
compilation_config = dict(
use_inductor_graph_partition=inductor_graph_partition,
@@ -103,9 +113,12 @@ def test_tp2_ar_rms_fp8_fusions(
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
"model_name, matches_fn, model_kwargs, hf_overrides",
[llama3_8b_fp4, llama4_scout_fp4],
[llama3_8b_fp4, llama4_scout_fp4, deepseek_v32_fp4],
)
@pytest.mark.parametrize(
"attn_backend",
[FLASHINFER_ATTN, FLASHMLA_SPARSE_ATTN],
)
@pytest.mark.parametrize("attn_backend", [FLASHINFER_ATTN])
@pytest.mark.parametrize("n_layers", [4])
@pytest.mark.parametrize("custom_ops", custom_ops_combos("rms_norm"))
@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION)
@@ -128,6 +141,7 @@ def test_tp2_ar_rms_fp4_fusions(
model_kwargs["hf_overrides"] = hf_overrides(n_layers)
model_kwargs["load_format"] = "dummy"
model_kwargs["max_model_len"] = 1024
model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False}
compilation_config = dict(
use_inductor_graph_partition=inductor_graph_partition,
@@ -182,6 +196,7 @@ def test_tp2_ar_rms_fusions(
model_kwargs["hf_overrides"] = hf_overrides(n_layers)
model_kwargs["load_format"] = "dummy"
model_kwargs["max_model_len"] = 1024
model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False}
compilation_config = dict(
use_inductor_graph_partition=inductor_graph_partition,

View File

@@ -58,6 +58,7 @@ def test_tp2_async_tp_fp8_fusions(
model_kwargs["hf_overrides"] = hf_overrides(n_layers)
model_kwargs["load_format"] = "dummy"
model_kwargs["max_model_len"] = 1024
model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False}
compilation_config = dict(
use_inductor_graph_partition=inductor_graph_partition,
@@ -121,6 +122,7 @@ def test_tp2_async_tp_fusions(
model_kwargs["hf_overrides"] = hf_overrides(n_layers)
model_kwargs["load_format"] = "dummy"
model_kwargs["max_model_len"] = 1024
model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False}
compilation_config = dict(
use_inductor_graph_partition=inductor_graph_partition,

View File

@@ -9,7 +9,6 @@ from tests.compile.backend import TestBackend
from tests.utils import TestFP8Layer, multi_gpu_test
from vllm.compilation.passes.fusion.rms_quant_fusion import RMSNormQuantFusionPass
from vllm.compilation.passes.fusion.sequence_parallelism import SequenceParallelismPass
from vllm.compilation.passes.fx_utils import find_auto_fn
from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass
from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
from vllm.compilation.passes.vllm_inductor_pass import VllmInductorPass
@@ -86,13 +85,14 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
]
def ops_in_model(self):
if RMSNorm.enabled():
return [
torch.ops._C.rms_norm.default,
return (
[torch.ops.vllm_ir.rms_norm]
+ [
torch.ops._C.fused_add_rms_norm.default,
]
else:
return []
if RMSNorm.enabled()
else []
)
class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
@@ -321,4 +321,4 @@ def sequence_parallelism_pass_on_test_model(
assert backend.op_count(op, before=False) == 4
for op in model.ops_in_model():
find_auto_fn(backend.graph_post_pass.nodes, op)
assert backend.op_count(op, before=False) > 0

View File

View File

@@ -0,0 +1,69 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from torch import nn
import vllm.kernels # noqa: F401 to register kernels
from vllm import ir
from vllm.compilation.passes.ir.lowering_pass import (
VllmIRLoweringPass,
)
from vllm.config import get_current_vllm_config
from vllm.ir import ops
from vllm.platforms import current_platform
from ...backend import TestBackend
class Model(nn.Module):
def __init__(self, hidden_size=16, *args, **kwargs):
super().__init__(*args, **kwargs)
self.hidden_size = hidden_size
self.weight = torch.ones(hidden_size, dtype=torch.bfloat16)
def forward(self, x):
x1 = x + 4.0
x2 = ops.rms_norm(x1, self.weight, 1e-5)
x3 = x2 * 5.0
# no weight
x4 = ops.rms_norm(x3, None, 1e-5)
x5 = x4 / 2.0
# dispatch to native due to variance_size parameter
x6 = ops.rms_norm(x5, self.weight, 1e-5, self.hidden_size // 2)
return x6 + 3.0
@pytest.mark.parametrize("rms_provider", ops.rms_norm.supported_providers())
def test_lowering_rms_norm(rms_provider, default_vllm_config):
torch.set_default_device(current_platform.device_type)
lowering_pass = VllmIRLoweringPass(get_current_vllm_config())
backend = TestBackend(lowering_pass)
backend_unlowered = TestBackend()
model = Model()
x = torch.randn(8, 16, dtype=torch.bfloat16)
with (
ops.rms_norm.set_priority([rms_provider, "native"]),
ir.enable_torch_wrap(True),
):
compiled_model = torch.compile(model, backend=backend, fullgraph=True)
compiled_unlowered_model = torch.compile(
model, backend=backend_unlowered, fullgraph=True
)
output = compiled_model(x)
output_unlowered = compiled_unlowered_model(x)
selected = lowering_pass.selected_impls["rms_norm"]
assert len(selected) == 3
assert selected["rms_norm"] == rms_provider
assert selected["rms_norm_1"] == rms_provider
assert selected["rms_norm_2"] == "native"
# Compiled function guards on global value, avoid recompilation
with ir.enable_torch_wrap(True):
output2 = compiled_model(x)
torch.testing.assert_close(output_unlowered, output)
torch.testing.assert_close(output_unlowered, output2)

View File

@@ -6,6 +6,7 @@ import pytest
import torch
import vllm.config
import vllm.ir.ops
import vllm.plugins
from tests.compile.backend import TestBackend
from tests.utils import TestBlockFP8Layer, TestFP8Layer
@@ -51,7 +52,6 @@ from vllm.utils.deep_gemm import (
FP8_DTYPE = current_platform.fp8_dtype()
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
# Kernel and group_shape combinations: (kernel, group_shape)
@@ -246,10 +246,8 @@ class TestModel(torch.nn.Module):
]
def ops_in_model_before_partial(self):
return (
[RMS_OP, RMS_ADD_OP]
if self.enable_rms_norm_custom_op
else [torch.ops.aten.rsqrt]
return [torch.ops.vllm_ir.rms_norm] + (
[RMS_ADD_OP] if self.enable_rms_norm_custom_op else [torch.ops.aten.rsqrt]
)
@@ -340,7 +338,10 @@ def test_fusion_rmsnorm_quant(
),
)
with vllm.config.set_current_vllm_config(vllm_config):
with (
vllm.config.set_current_vllm_config(vllm_config),
vllm_config.kernel_config.ir_op_priority.set_priority(),
):
# Setup device before model creation
torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
@@ -370,8 +371,9 @@ def test_fusion_rmsnorm_quant(
# Hence, we check only 2 add nodes are left (final fused rmsnorm add).
if not enable_rms_norm_custom_op:
n_add_nodes = lambda g: sum(1 for _ in find_op_nodes(torch.ops.aten.add, g))
# 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each)
assert n_add_nodes(backend.graph_pre_pass) == 7
# rms_norm is IR, not included
# 6 = 3x2 (3xRMS_ADD, 2 each)
assert n_add_nodes(backend.graph_pre_pass) == 6
assert n_add_nodes(backend.graph_post_pass) == 2

View File

@@ -9,7 +9,10 @@ from tests.compile.backend import LazyInitPass, TestBackend
from tests.utils import TestFP8Layer, flat_product
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.compilation.passes.fusion.attn_quant_fusion import ATTN_OP, AttnFusionPass
from vllm.compilation.passes.fusion.attn_quant_fusion import (
ATTN_OP,
AttnQuantFusionPass,
)
from vllm.compilation.passes.fusion.matcher_utils import QUANT_OPS
from vllm.compilation.passes.fx_utils import find_op_nodes
from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass
@@ -384,7 +387,7 @@ def test_attention_quant_pattern(
# Create test backend with fusion passes enabled
noop_pass = NoOpEliminationPass(vllm_config)
attn_pass = LazyInitPass(AttnFusionPass, vllm_config)
attn_pass = LazyInitPass(AttnQuantFusionPass, vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
test_backend = TestBackend(noop_pass, attn_pass, cleanup_pass)
@@ -434,7 +437,7 @@ def test_attention_quant_pattern(
# Only output quant ops are fused into attention.
test_backend.check_before_ops([quant_op], fully_replaced=quant_key is kNvfp4Dynamic)
# access the underlying `AttnFusionPass` on the `LazyInitPass`
# access the underlying `AttnQuantFusionPass` on the `LazyInitPass`
assert attn_pass.pass_.matched_count == sum(attn_fusion_supported)
# Check attention ops in the graph before and after fusion

View File

@@ -0,0 +1,508 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import pytest
import torch._dynamo
from tests.compile.backend import LazyInitPass, TestBackend
from tests.utils import TestFP8Layer, flat_product
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.compilation.passes.fusion.matcher_utils import QUANT_OPS
from vllm.compilation.passes.fusion.mla_attn_quant_fusion import (
MLA_ATTN_OP,
MLAAttnQuantFusionPass,
)
from vllm.compilation.passes.fx_utils import find_op_nodes
from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass
from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
from vllm.config import (
AttentionConfig,
CacheConfig,
CompilationConfig,
CompilationMode,
ModelConfig,
PassConfig,
SchedulerConfig,
VllmConfig,
set_current_vllm_config,
)
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.model_executor.layers.attention import MLAAttention
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.quantization.modelopt import ModelOptNvFp4Config
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8StaticTensorSym,
kNvfp4Dynamic,
)
from vllm.platforms import current_platform
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.kv_cache_interface import MLAAttentionSpec
FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
class MLAAttentionQuantPatternModel(torch.nn.Module):
"""Base model for MLA AttentionQuantPattern fusion."""
def __init__(
self,
num_heads: int,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
v_head_dim: int,
kv_lora_rank: int,
kv_cache_dtype: torch.dtype,
device: torch.device,
vllm_config: VllmConfig,
**kwargs,
):
super().__init__()
self.num_heads = num_heads
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
self.v_head_dim = v_head_dim
self.kv_lora_rank = kv_lora_rank
self.output_dim = num_heads * v_head_dim
self.head_size = kv_lora_rank + qk_rope_head_dim
self.kv_cache_dtype = kv_cache_dtype
self.device = device
self.vllm_config = vllm_config
# Create kv_b_proj (ColumnParallelLinear) on device.
# Reuse weights from prior model instance when available, because
# ColumnParallelLinear may get NaN from recycled CUDA memory after
# torch.compile runs in the same process.
kv_b_proj = ColumnParallelLinear(
input_size=kv_lora_rank,
output_size=num_heads * (qk_nope_head_dim + v_head_dim),
bias=False,
prefix="model.layers.0.self_attn.kv_b_proj",
).to(device)
kv_b_proj_weight = kwargs.get("kv_b_proj_weight")
if kv_b_proj_weight is not None:
kv_b_proj.weight.data.copy_(kv_b_proj_weight)
elif kv_b_proj.weight.data.isnan().any():
# Sanitize NaN from recycled CUDA memory
kv_b_proj.weight.data.normal_()
# Create MLAAttention
self.mla_attn = MLAAttention(
num_heads=num_heads,
scale=1.0 / (self.qk_head_dim**0.5),
qk_nope_head_dim=qk_nope_head_dim,
qk_rope_head_dim=qk_rope_head_dim,
v_head_dim=v_head_dim,
q_lora_rank=None,
kv_lora_rank=kv_lora_rank,
kv_b_proj=kv_b_proj,
cache_config=vllm_config.cache_config,
quant_config=self.quant_config,
prefix="model.layers.0.self_attn.attn",
)
self.mla_attn._k_scale = self.mla_attn._k_scale.to(device)
self.mla_attn._v_scale = self.mla_attn._v_scale.to(device)
# Initialize W_UK_T and W_UV from kv_b_proj weights
self.mla_attn.process_weights_after_loading(torch.get_default_dtype())
self.kv_b_proj_weight = kv_b_proj.weight.data.clone()
self.block_size = 16
# Initialize MLA MetadataBuilder
self.builder = self.mla_attn.attn_backend.get_builder_cls()(
kv_cache_spec=MLAAttentionSpec(
block_size=self.block_size,
num_kv_heads=1,
head_size=self.head_size,
dtype=self.kv_cache_dtype,
),
layer_names=[self.mla_attn.layer_name],
vllm_config=self.vllm_config,
device=self.device,
)
def build_attn_metadata(self, batch_size: int) -> AttentionMetadata:
"""Initialize MLA attention metadata.
NOTE: Uses decode-only batch (query_len=1 per request). The prefill
(forward_mha) path is not separately tested here because it requires
FlashAttention availability and different input tensor shapes. The
quant logic in forward_impl is identical for both paths — it quantizes
the full output[:num_actual_toks] buffer after both forward_mha and
forward_mqa have written their results.
"""
batch_spec = BatchSpec(seq_lens=[1] * batch_size, query_lens=[1] * batch_size)
common_attn_metadata = create_common_attn_metadata(
batch_spec, self.block_size, self.device, arange_block_indices=True
)
max_blocks = (max(batch_spec.seq_lens) + self.block_size - 1) // self.block_size
num_blocks = batch_size * max_blocks
# MLA KV cache is 3D: (num_blocks, block_size, head_size)
attn_backend = self.mla_attn.attn_backend
kv_cache_shape = attn_backend.get_kv_cache_shape(
num_blocks, self.block_size, 1, self.head_size
)
try:
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
except (AttributeError, NotImplementedError):
kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
ordered_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
inv_order = [
kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order))
]
raw_tensor = torch.zeros(
ordered_shape, dtype=self.kv_cache_dtype, device=self.device
)
kv_cache = raw_tensor.permute(*inv_order)
self.mla_attn.kv_cache = kv_cache
self.attn_metadata = self.builder.build(
common_prefix_len=0, common_attn_metadata=common_attn_metadata
)
return self.attn_metadata
class TestMLAAttentionFp8StaticQuantPatternModel(MLAAttentionQuantPatternModel):
"""Test model for MLA Attention + FP8 static quant fusion."""
quant_key = kFp8StaticTensorSym
quant_config = Fp8Config()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.fp8_linear = TestFP8Layer(
weight_shape=(self.output_dim, self.output_dim),
activation_quant_key=self.quant_key,
weight_quant_key=self.quant_key,
device=self.device,
)
w = kwargs.get("w")
if w is not None:
self.fp8_linear.weight = w["weight"]
self.fp8_linear.weight_scale = w["wscale"]
self.fp8_linear.input_scale = w["scale"]
self.w = {
"weight": self.fp8_linear.weight,
"wscale": self.fp8_linear.weight_scale,
"scale": self.fp8_linear.input_scale,
}
def forward(
self,
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
):
"""Forward pass that creates the MLA attention + FP8 quant pattern."""
attn_output = self.mla_attn(
q,
kv_c_normed,
k_pe,
output_shape=(q.shape[0], self.output_dim),
)
return self.fp8_linear(attn_output)
class TestMLAAttentionNvfp4QuantPatternModel(MLAAttentionQuantPatternModel):
"""Test model for MLA Attention + NVFP4 quant fusion."""
quant_key = kNvfp4Dynamic
quant_config = ModelOptNvFp4Config(
is_checkpoint_nvfp4_serialized=False,
kv_cache_quant_algo=None,
exclude_modules=[],
)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.w = kwargs.get(
"w",
{
"weight": torch.randint(
256,
(self.output_dim, self.output_dim // 2),
dtype=FP4_DTYPE,
device=self.device,
),
"wscale_swizzled": torch.randn(
self.output_dim, self.output_dim // 16
).to(dtype=FP8_DTYPE, device=self.device),
"wscale": torch.tensor([500], dtype=torch.float32, device=self.device),
"scale": torch.tensor([0.002], dtype=torch.float32, device=self.device),
},
)
def forward(
self,
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
):
"""Forward pass that creates the MLA attention + NVFP4 quant pattern."""
attn_output = self.mla_attn(
q,
kv_c_normed,
k_pe,
output_shape=(q.shape[0], self.output_dim),
)
quant_output, output_block_scale = scaled_fp4_quant(
attn_output, 1 / self.w["scale"]
)
return cutlass_scaled_fp4_mm(
a=quant_output,
b=self.w["weight"],
block_scale_a=output_block_scale,
block_scale_b=self.w["wscale_swizzled"],
alpha=self.w["scale"] * self.w["wscale"],
out_dtype=attn_output.dtype,
)
def is_nvfp4_supported():
return current_platform.has_device_capability(100)
# MLA test configuration
MLA_DIMS: list[tuple[int, int, int, int, int]] = []
PATTERN_TEST_MODELS_MLA_FP8: list[tuple[str, type]] = []
PATTERN_TEST_MODELS_MLA_FP4: list[tuple[str, type]] = []
BACKENDS_MLA_FP8: list[AttentionBackendEnum] = []
BACKENDS_MLA_FP4: list[AttentionBackendEnum] = []
if current_platform.is_cuda():
# (num_heads, qk_nope_head_dim, qk_rope_head_dim, v_head_dim, kv_lora_rank)
MLA_DIMS = [(16, 128, 64, 128, 512)]
PATTERN_TEST_MODELS_MLA_FP8 = [
(
"deepseek-ai/DeepSeek-V2-Lite",
TestMLAAttentionFp8StaticQuantPatternModel,
)
]
PATTERN_TEST_MODELS_MLA_FP4 = [
(
"deepseek-ai/DeepSeek-V2-Lite",
TestMLAAttentionNvfp4QuantPatternModel,
)
]
BACKENDS_MLA_FP8 = [AttentionBackendEnum.TRITON_MLA]
BACKENDS_MLA_FP4 = [AttentionBackendEnum.TRITON_MLA]
@pytest.mark.parametrize(
"num_heads, qk_nope_head_dim, qk_rope_head_dim, v_head_dim, kv_lora_rank",
MLA_DIMS,
)
@pytest.mark.parametrize("batch_size", [7, 256] if current_platform.is_cuda() else [8])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize(
"backend, model_name, model_class, custom_ops",
list(
flat_product(
BACKENDS_MLA_FP8,
PATTERN_TEST_MODELS_MLA_FP8,
["+quant_fp8", "-quant_fp8"],
)
)
+ list(flat_product(BACKENDS_MLA_FP4, PATTERN_TEST_MODELS_MLA_FP4, [""])),
)
@pytest.mark.skipif(
not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA"
)
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
def test_mla_attention_quant_pattern(
num_heads: int,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
v_head_dim: int,
kv_lora_rank: int,
batch_size: int,
dtype: torch.dtype,
custom_ops: str,
model_name: str,
model_class: type[MLAAttentionQuantPatternModel],
backend: AttentionBackendEnum,
dist_init,
monkeypatch,
use_fresh_inductor_cache,
):
"""Test MLA AttentionQuantPattern fusion pass"""
if (
model_class is TestMLAAttentionNvfp4QuantPatternModel
and not is_nvfp4_supported()
):
pytest.skip("NVFP4 is not supported on this GPU (requires SM 100+).")
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
custom_ops_list = custom_ops.split(",") if custom_ops else []
device = torch.device("cuda:0")
torch.set_default_dtype(dtype)
torch.manual_seed(42)
model_config = ModelConfig(
model=model_name,
max_model_len=2048,
dtype=dtype,
)
vllm_config = VllmConfig(
model_config=model_config,
scheduler_config=SchedulerConfig(
max_num_seqs=1024,
max_model_len=model_config.max_model_len,
is_encoder_decoder=model_config.is_encoder_decoder,
),
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
custom_ops=custom_ops_list,
),
cache_config=CacheConfig(cache_dtype="auto"),
attention_config=AttentionConfig(backend=backend),
)
# MLA inputs: q(B, N, qk_head_dim), kv_c_normed(B, L), k_pe(B, 1, R)
qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
q = torch.randn(batch_size, num_heads, qk_head_dim, dtype=dtype, device=device)
kv_c_normed = torch.randn(batch_size, kv_lora_rank, dtype=dtype, device=device)
k_pe = torch.randn(batch_size, 1, qk_rope_head_dim, dtype=dtype, device=device)
# Mark first dimension as dynamic
torch._dynamo.mark_dynamic(q, 0)
torch._dynamo.mark_dynamic(kv_c_normed, 0)
torch._dynamo.mark_dynamic(k_pe, 0)
# Run model without fusion
vllm_config_unfused = copy.deepcopy(vllm_config)
with (
set_current_vllm_config(vllm_config_unfused),
set_forward_context(attn_metadata=None, vllm_config=vllm_config_unfused),
):
model_unfused = model_class(
num_heads=num_heads,
qk_nope_head_dim=qk_nope_head_dim,
qk_rope_head_dim=qk_rope_head_dim,
v_head_dim=v_head_dim,
kv_lora_rank=kv_lora_rank,
kv_cache_dtype=dtype,
device=device,
vllm_config=vllm_config_unfused,
)
model_unfused = model_unfused.to(device)
# HACK: See #131044
result_unfused_0 = model_unfused(q, kv_c_normed, k_pe) # noqa: F841
forward_ctx = get_forward_context()
forward_ctx.attn_metadata = model_unfused.build_attn_metadata(batch_size)
compiled_unfused = torch.compile(model_unfused, fullgraph=True)
result_unfused = compiled_unfused(q, kv_c_normed, k_pe)
# Run model with attn fusion enabled
vllm_config.compilation_config.pass_config = PassConfig(
fuse_attn_quant=True, eliminate_noops=True
)
with (
set_current_vllm_config(vllm_config),
set_forward_context(attn_metadata=None, vllm_config=vllm_config),
):
model_fused = model_class(
num_heads=num_heads,
qk_nope_head_dim=qk_nope_head_dim,
qk_rope_head_dim=qk_rope_head_dim,
v_head_dim=v_head_dim,
kv_lora_rank=kv_lora_rank,
kv_cache_dtype=dtype,
device=device,
vllm_config=vllm_config,
w=model_unfused.w,
kv_b_proj_weight=model_unfused.kv_b_proj_weight,
)
model_fused = model_fused.to(device)
forward_ctx = get_forward_context()
forward_ctx.attn_metadata = model_fused.build_attn_metadata(batch_size)
# Create test backend with fusion passes
noop_pass = NoOpEliminationPass(vllm_config)
attn_pass = LazyInitPass(MLAAttnQuantFusionPass, vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
test_backend = TestBackend(noop_pass, attn_pass, cleanup_pass)
# HACK: See https://github.com/vllm-project/vllm/issues/31044
result_fused_0 = model_fused(q, kv_c_normed, k_pe) # noqa: F841
compiled_fused = torch.compile(
model_fused, backend=test_backend, fullgraph=True
)
result_fused = compiled_fused(q, kv_c_normed, k_pe)
# Check attn fusion support
quant_key: QuantKey = model_class.quant_key
attn_fusion_supported = [
layer.impl.fused_output_quant_supported(quant_key)
for key, layer in vllm_config.compilation_config.static_forward_context.items()
if isinstance(layer, MLAAttention)
]
assert sum(attn_fusion_supported) == len(attn_fusion_supported), (
"All MLA layers should support attention fusion"
)
# Check quantization ops in the graph
quant_op = (
torch.ops.aten.reciprocal
if "-quant_fp8" in custom_ops_list
else QUANT_OPS[quant_key]
)
test_backend.check_before_ops([quant_op], fully_replaced=quant_key is kNvfp4Dynamic)
assert attn_pass.pass_.matched_count == sum(attn_fusion_supported)
# Check MLA attention ops in the graph
attn_nodes_pre = list(find_op_nodes(MLA_ATTN_OP, test_backend.graph_pre_pass))
attn_nodes_post = list(find_op_nodes(MLA_ATTN_OP, test_backend.graph_post_pass))
assert len(attn_nodes_pre) > 0, "Should have MLA attention nodes before fusion"
assert len(attn_nodes_pre) == len(attn_nodes_post), (
"Should have same number of MLA attention nodes before and after fusion"
)
assert attn_nodes_pre[0].kwargs.get("output_scale") is None, (
"MLA attention should not have output_scale before fusion"
)
assert attn_nodes_post[0].kwargs.get("output_scale") is not None, (
"MLA attention should have output_scale after fusion"
)
assert attn_nodes_pre[0].kwargs.get("output_block_scale") is None, (
"MLA attention should not have output_block_scale before fusion"
)
if quant_key.dtype == FP8_DTYPE:
assert attn_nodes_post[0].kwargs.get("output_block_scale") is None, (
"MLA attention should not have output_block_scale after FP8 fusion"
)
elif quant_key.dtype == FP4_DTYPE:
assert attn_nodes_post[0].kwargs.get("output_block_scale") is not None, (
"MLA attention should have output_block_scale after FP4 fusion"
)
# Check numerical correctness
torch.testing.assert_close(result_unfused, result_fused, atol=1e-2, rtol=1e-2)

View File

@@ -3,11 +3,11 @@
import pytest
import torch
from torch._ops import OpOverload, OpOverloadPacket
from tests.compile.backend import TestBackend
from vllm.compilation.passes.fusion.matcher_utils import (
FLASHINFER_ROTARY_OP,
RMS_OP,
ROTARY_OP,
)
from vllm.compilation.passes.fusion.qk_norm_rope_fusion import (
@@ -100,13 +100,8 @@ class QKNormRoPETestModel(torch.nn.Module):
q, k = self.rotary_emb(positions, q, k)
return q, k, v
def ops_in_model_before(self) -> list[torch._ops.OpOverload]:
ops = []
if self.enable_rms_norm_custom_op:
ops.append(RMS_OP)
else:
ops.append(RSQRT_OP)
def ops_in_model_before(self) -> list[OpOverload | OpOverloadPacket]:
ops: list[OpOverload | OpOverloadPacket] = [torch.ops.vllm_ir.rms_norm]
if self.enable_rope_custom_op:
if self.rotary_emb.use_flashinfer:
ops.append(FLASHINFER_ROTARY_OP)
@@ -116,7 +111,7 @@ class QKNormRoPETestModel(torch.nn.Module):
ops.append(INDEX_SELECT_OP)
return ops
def ops_in_model_after(self) -> list[torch._ops.OpOverload]:
def ops_in_model_after(self) -> list[OpOverload | OpOverloadPacket]:
return [FUSED_QK_ROPE_OP]
@@ -166,7 +161,10 @@ def test_qk_norm_rope_fusion(
num_heads, num_kv_heads, head_dim = 16, 4, 128
T = 5
with set_current_vllm_config(vllm_config):
with (
set_current_vllm_config(vllm_config),
vllm_config.kernel_config.ir_op_priority.set_priority(),
):
model = QKNormRoPETestModel(
num_heads=num_heads,
num_kv_heads=num_kv_heads,

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from functools import partial
import pytest
import torch
@@ -34,13 +35,16 @@ from vllm.model_executor.kernels.linear import (
ROCmFP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
kFp8Dynamic128Sym,
kFp8StaticTensorSym,
kNvfp4Dynamic,
)
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import is_deep_gemm_supported
FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
@@ -165,6 +169,48 @@ class TestSiluMulGroupFp8QuantModel(torch.nn.Module):
return [torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant]
class TestSiluMulBlockQuantModel(torch.nn.Module):
quant_key = kFp8Dynamic128Sym
def __init__(self, hidden_size: int, is_scale_transposed: bool = False, **kwargs):
super().__init__()
self.silu_and_mul = SiluAndMul()
self.is_scale_transposed = is_scale_transposed
self.quant_fp8 = QuantFP8(
static=False,
group_shape=GroupShape(1, 128),
column_major_scales=is_scale_transposed,
compile_native=False,
)
self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()
self.enable_quant_fp8_custom_op = self.quant_fp8.enabled()
def forward(self, x):
y = self.silu_and_mul(x)
out, scale = self.quant_fp8(y)
group_size = self.quant_key.scale.group_shape[1]
scale_expanded = scale.repeat_interleave(group_size, dim=1)
dequant = out.to(dtype=torch.float32) * scale_expanded
return (dequant,)
def ops_in_model_before(self):
ops = []
if self.enable_silu_mul_custom_op:
ops.append(SILU_MUL_OP)
# When silu custom op is disabled, aten.mul.Tensor also appears
# in dequant code, so we skip checking it to avoid false positives.
ops.append(
QUANT_OPS[self.quant_key]
if self.enable_quant_fp8_custom_op
else torch.ops.aten.reciprocal.default
)
return ops
def ops_in_model_after(self):
return [FUSED_OPS[self.quant_key]]
ROCM_KERNELS = [ROCmFP8ScaledMMLinearKernel, PerTensorTorchFP8ScaledMMLinearKernel]
CUDA_KERNELS = [
FlashInferFP8ScaledMMLinearKernel,
@@ -200,6 +246,19 @@ TEST_KERNELS = ROCM_KERNELS if current_platform.is_rocm() else CUDA_KERNELS
not current_platform.is_rocm(), reason="ROCm only"
),
),
# Block quant fusion for per-group FP8 (CUDA only).
*[
pytest.param(
partial(TestSiluMulBlockQuantModel, is_scale_transposed=transposed),
True,
None,
marks=pytest.mark.skipif(
not current_platform.is_cuda(), reason="CUDA only"
),
id=f"TestSiluMulBlockQuant-transposed={transposed}",
)
for transposed in [False, True]
],
],
)
@pytest.mark.skipif(
@@ -213,6 +272,7 @@ def test_fusion_silu_and_mul_quant(
TestSiluMulFp8QuantModel
| TestSiluMulNvfp4QuantModel
| TestSiluMulGroupFp8QuantModel
| TestSiluMulBlockQuantModel
],
enable_silu_mul_custom_op: bool,
enable_quant_fp8_custom_op: bool,
@@ -223,6 +283,12 @@ def test_fusion_silu_and_mul_quant(
pytest.skip("NVFP4 is not supported on this GPU.")
if model_class is TestSiluMulGroupFp8QuantModel and not IS_AITER_FOUND:
pytest.skip("AITER is not supported on this GPU.")
if (
isinstance(model_class, partial)
and model_class.func is TestSiluMulBlockQuantModel
and is_deep_gemm_supported()
):
pytest.skip("SiluMul+BlockQuant fusion not applicable with DeepGemm")
torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
@@ -269,11 +335,13 @@ def test_fusion_silu_and_mul_quant(
result2 = model2(x)
# Check that it gives the same answer
if model_class == TestSiluMulFp8QuantModel:
if isinstance(model, TestSiluMulFp8QuantModel):
atol, rtol = 1e-3, 1e-3
elif model_class == TestSiluMulNvfp4QuantModel:
elif isinstance(model, TestSiluMulNvfp4QuantModel):
atol, rtol = 1e-1, 1e-1
elif model_class == TestSiluMulGroupFp8QuantModel:
elif isinstance(
model, (TestSiluMulGroupFp8QuantModel, TestSiluMulBlockQuantModel)
):
atol, rtol = 5e-2, 5e-2
torch.testing.assert_close(

View File

@@ -1622,3 +1622,26 @@ def fresh_vllm_cache(monkeypatch, use_fresh_inductor_cache):
def enable_pickle(monkeypatch):
"""`LLM.apply_model` requires pickling a function."""
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
@pytest.fixture(scope="function")
def disable_log_dedup(monkeypatch):
"""
Disable log deduplication such that warning_once and info_once always print.
"""
# Patch logger._print_warning_once to remove the lru_cache decorator
from vllm import logger
original_print_warning_once = logger._print_warning_once
original_print_info_once = logger._print_info_once
original_print_debug_once = logger._print_debug_once
logger._print_warning_once = original_print_warning_once.__wrapped__
logger._print_info_once = original_print_info_once.__wrapped__
logger._print_debug_once = original_print_debug_once.__wrapped__
yield
logger._print_warning_once = original_print_warning_once
logger._print_info_once = original_print_info_once
logger._print_debug_once = original_print_debug_once

View File

@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import random
import msgspec
@@ -166,3 +167,31 @@ class MockSubscriber:
self.sub.close()
for replay in self.replay_sockets:
replay.close()
@pytest.fixture
def enable_ray_v2_backend():
"""Set env vars for the Ray V2 executor backend and shut down Ray
between tests."""
import ray
saved = {
"VLLM_USE_RAY_V2_EXECUTOR_BACKEND": os.environ.get(
"VLLM_USE_RAY_V2_EXECUTOR_BACKEND"
),
"VLLM_ENABLE_V1_MULTIPROCESSING": os.environ.get(
"VLLM_ENABLE_V1_MULTIPROCESSING"
),
}
os.environ["VLLM_USE_RAY_V2_EXECUTOR_BACKEND"] = "1"
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
if ray.is_initialized():
ray.shutdown()
try:
yield
finally:
if ray.is_initialized():
ray.shutdown()
os.environ.update({k: v for k, v in saved.items() if v is not None})
for key in (k for k, v in saved.items() if v is None):
os.environ.pop(key, None)

View File

@@ -0,0 +1,119 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Multi-node integration test for MessageQueue TCP fallback.
Verifies that when writer and readers span separate nodes (Docker containers
with isolated /dev/shm), `create_from_process_group` correctly detects
cross-node ranks via `in_the_same_node_as()` and falls back to ZMQ TCP
transport — and that data actually arrives.
"""
import numpy as np
import torch.distributed as dist
from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
from vllm.distributed.parallel_state import in_the_same_node_as
def main():
dist.init_process_group(backend="gloo")
rank = dist.get_rank()
world_size = dist.get_world_size()
assert world_size >= 2, (
f"Need at least 2 ranks across nodes, got world_size={world_size}"
)
# Verify that in_the_same_node_as detects cross-node correctly
status = in_the_same_node_as(dist.group.WORLD, source_rank=0)
local_count = sum(status)
print(
f"[Rank {rank}] in_the_same_node_as(source=0): {status} "
f"(local={local_count}/{world_size})"
)
# With 2 Docker containers (1 proc each), rank 0 and rank 1
# should be on different nodes.
assert local_count < world_size, (
f"Expected cross-node ranks but all {world_size} ranks appear local."
)
# Create MessageQueue
writer_rank = 0
mq = MessageQueue.create_from_process_group(
dist.group.WORLD,
max_chunk_bytes=1024 * 1024, # 1 MiB
max_chunks=10,
writer_rank=writer_rank,
)
# Verify the transport path selection
if rank == writer_rank:
print(
f"[Rank {rank}] Writer: n_local_reader={mq.n_local_reader}, "
f"n_remote_reader={mq.n_remote_reader}"
)
assert mq.n_remote_reader > 0, (
"Writer should have at least 1 remote (TCP) reader in a multi-node setup."
)
else:
if status[rank]:
assert mq._is_local_reader, (
f"Rank {rank} is on the same node as writer but is not a local reader."
)
print(f"[Rank {rank}] Reader: local (shared memory)")
else:
assert mq._is_remote_reader, (
f"Rank {rank} is on a different node but is not a remote (TCP) reader."
)
print(f"[Rank {rank}] Reader: remote (TCP)")
# Test data transfer: simple objects
dist.barrier()
if rank == writer_rank:
mq.enqueue("hello_from_node0")
else:
msg = mq.dequeue(timeout=10)
assert msg == "hello_from_node0"
dist.barrier()
print(f"[Rank {rank}] Simple object test passed")
# Test data transfer: numpy arrays
np.random.seed(42)
arrays = [
np.random.randint(0, 100, size=np.random.randint(100, 5000)) for _ in range(100)
]
dist.barrier()
if rank == writer_rank:
for arr in arrays:
mq.enqueue(arr)
else:
for i, expected in enumerate(arrays):
received = mq.dequeue(timeout=10)
assert np.array_equal(expected, received), (
f"Array mismatch at index {i}: "
f"expected shape {expected.shape}, got shape {received.shape}"
)
dist.barrier()
print(f"[Rank {rank}] Numpy array test passed")
# Test data transfer: large payload (> max_chunk_bytes)
dist.barrier()
big_array = np.zeros(200_000, dtype=np.int64) # ~1.6 MiB > 1 MiB chunk
if rank == writer_rank:
mq.enqueue(big_array)
else:
received = mq.dequeue(timeout=10)
assert np.array_equal(big_array, received)
dist.barrier()
print(f"[Rank {rank}] Large payload test passed")
# Done -- cleanup
dist.barrier()
print(f"[Rank {rank}] All MessageQueue TCP multi-node tests passed!")
dist.destroy_process_group()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,345 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Integration tests for RayExecutorV2 at the executor level.
Validates executor initialization, placement group support, RPC calls,
and distributed execution with various TP/PP configurations.
"""
import gc
import threading
from unittest.mock import patch
import pytest
import ray
from vllm import LLM
from vllm.config import VllmConfig
from vllm.engine.arg_utils import EngineArgs
from vllm.v1.executor.ray_executor_v2 import RayExecutorV2
pytestmark = pytest.mark.usefixtures("enable_ray_v2_backend")
MODEL = "facebook/opt-125m"
def create_vllm_config(
tensor_parallel_size: int = 1,
pipeline_parallel_size: int = 1,
max_model_len: int = 256,
gpu_memory_utilization: float = 0.3,
placement_group=None,
) -> VllmConfig:
engine_args = EngineArgs(
model=MODEL,
tensor_parallel_size=tensor_parallel_size,
pipeline_parallel_size=pipeline_parallel_size,
max_model_len=max_model_len,
gpu_memory_utilization=gpu_memory_utilization,
distributed_executor_backend="ray",
enforce_eager=True,
)
vllm_config = engine_args.create_engine_config()
if placement_group is not None:
vllm_config.parallel_config.placement_group = placement_group
return vllm_config
def ensure_ray_initialized():
if not ray.is_initialized():
ray.init(ignore_reinit_error=True)
@pytest.fixture
def create_placement_group(request):
ensure_ray_initialized()
num_gpus = request.param
bundles = [{"GPU": 1, "CPU": 1} for _ in range(num_gpus)]
pg = ray.util.placement_group(bundles, strategy="PACK")
ray.get(pg.ready())
yield pg
ray.util.remove_placement_group(pg)
@pytest.fixture
def executor(request):
"""Create a RayExecutorV2 and shut it down after the test."""
executor = RayExecutorV2(vllm_config=request.param)
yield executor
executor.shutdown()
def assert_executor(executor, tp_size, pp_size):
"""Common assertions for executor initialization tests."""
world_size = tp_size * pp_size
expected_output_rank = (pp_size - 1) * tp_size
assert executor.world_size == world_size
assert len(executor.ray_worker_handles) == world_size
assert len(executor.response_mqs) == world_size
assert executor._get_output_rank() == expected_output_rank
if pp_size > 1:
assert executor.max_concurrent_batches == pp_size
executor.check_health()
assert not executor.is_failed
ranks = sorted(h.rank for h in executor.ray_worker_handles)
assert ranks == list(range(world_size))
for handle in executor.ray_worker_handles:
assert handle.node_id is not None
@pytest.mark.parametrize("tp_size, pp_size", [(1, 1), (2, 1), (4, 1), (2, 2)])
def test_ray_v2_executor(tp_size, pp_size):
"""Validate RayExecutorV2 with various TP/PP configs."""
vllm_config = create_vllm_config(
tensor_parallel_size=tp_size,
pipeline_parallel_size=pp_size,
)
executor = RayExecutorV2(vllm_config=vllm_config)
try:
assert_executor(executor, tp_size, pp_size)
finally:
executor.shutdown()
@pytest.mark.parametrize(
"tp_size, pp_size, create_placement_group",
[(2, 1, 2), (4, 1, 4), (2, 2, 4)],
indirect=["create_placement_group"],
)
def test_ray_v2_executor_pg(tp_size, pp_size, create_placement_group):
"""Validate RayExecutorV2 with various TP/PP configs using external PG."""
vllm_config = create_vllm_config(
tensor_parallel_size=tp_size,
pipeline_parallel_size=pp_size,
placement_group=create_placement_group,
)
executor = RayExecutorV2(vllm_config=vllm_config)
try:
assert_executor(executor, tp_size, pp_size)
finally:
executor.shutdown()
@pytest.mark.parametrize(
"executor",
[create_vllm_config(tensor_parallel_size=2)],
indirect=True,
)
def test_ray_v2_executor_failure_callback(executor):
"""Validate failure callback registration."""
callback_invoked = False
def test_callback():
nonlocal callback_invoked
callback_invoked = True
executor.register_failure_callback(test_callback)
assert not callback_invoked
executor.is_failed = True
executor.register_failure_callback(test_callback)
assert callback_invoked
@pytest.mark.parametrize(
"executor",
[create_vllm_config(tensor_parallel_size=2)],
indirect=True,
)
def test_ray_v2_executor_collective_rpc(executor):
"""Validate collective RPC calls through MessageQueue."""
executor.check_health()
assert not executor.is_failed
assert executor.rpc_broadcast_mq is not None
@pytest.mark.parametrize(
"executor",
[create_vllm_config(tensor_parallel_size=2)],
indirect=True,
)
def test_ray_v2_executor_driver_node_rank_0(executor):
"""Validate that driver node workers get the lowest ranks."""
driver_node = ray.get_runtime_context().get_node_id()
for handle in executor.ray_worker_handles:
assert handle.node_id == driver_node
rank0_handle = next(h for h in executor.ray_worker_handles if h.rank == 0)
assert rank0_handle.node_id == driver_node
@pytest.mark.parametrize(
"executor",
[create_vllm_config(tensor_parallel_size=2)],
indirect=True,
)
def test_ray_v2_executor_worker_death(executor):
"""Validate executor detects worker death via ray.wait()."""
callback_event = threading.Event()
def on_failure():
callback_event.set()
executor.register_failure_callback(on_failure)
assert not executor.is_failed
# Kill one worker actor externally
victim = executor.ray_worker_handles[1].actor
ray.kill(victim, no_restart=True)
# Monitor thread should detect the death and invoke callback
assert callback_event.wait(timeout=30)
assert executor.is_failed
assert executor.shutting_down
def test_ray_v2_executor_shutdown():
"""Validate graceful shutdown: ray.kill() terminates all worker actors."""
executor = RayExecutorV2(vllm_config=create_vllm_config(tensor_parallel_size=2))
assert executor.rpc_broadcast_mq is not None
assert len(executor.response_mqs) == executor.world_size
actors = [h.actor for h in executor.ray_worker_handles]
executor.shutdown()
for actor in actors:
with pytest.raises(ray.exceptions.RayActorError):
ray.get(actor.wait_for_init.remote(), timeout=5)
assert executor.rpc_broadcast_mq is None
assert len(executor.response_mqs) == 0
@pytest.mark.parametrize(
"executor",
[create_vllm_config(tensor_parallel_size=2)],
indirect=True,
)
def test_ray_v2_run_refs_stored_for_monitoring(executor):
"""Validate worker handles store run_ref for monitoring."""
for handle in executor.ray_worker_handles:
assert handle.run_ref is not None
ready, _ = ray.wait([handle.run_ref], timeout=0)
assert len(ready) == 0, "run_ref should be pending"
@pytest.mark.parametrize("tp_size, pp_size", [(2, 1), (2, 2)])
def test_ray_v2_single_node_generation(tp_size, pp_size):
"""End-to-end LLM generation with RayExecutorV2."""
llm = LLM(
model=MODEL,
tensor_parallel_size=tp_size,
pipeline_parallel_size=pp_size,
distributed_executor_backend="ray",
enforce_eager=True,
max_model_len=256,
gpu_memory_utilization=0.3,
)
try:
prompts = [
"Hello, my name is",
"The capital of France is",
"The future of AI is",
]
outputs = llm.generate(prompts)
assert len(outputs) == len(prompts)
for output in outputs:
assert len(output.outputs) > 0
assert len(output.outputs[0].text) > 0
finally:
llm.llm_engine.model_executor.shutdown()
del llm
gc.collect()
@pytest.mark.parametrize(
"bundle_indices, expected_bundle_ids, create_placement_group",
[("2,3", [2, 3], 4), ("3,2", [3, 2], 4)],
indirect=["create_placement_group"],
)
def test_ray_v2_bundle_indices_env(
bundle_indices, expected_bundle_ids, create_placement_group, monkeypatch
):
"""Validate explicit VLLM_RAY_BUNDLE_INDICES bundle placement."""
monkeypatch.setenv("VLLM_RAY_BUNDLE_INDICES", bundle_indices)
vllm_config = create_vllm_config(
tensor_parallel_size=2,
placement_group=create_placement_group,
)
executor = RayExecutorV2(vllm_config=vllm_config)
try:
actual = [
h.bundle_id_idx
for h in sorted(executor.ray_worker_handles, key=lambda h: h.rank)
]
assert actual == expected_bundle_ids
assert_executor(executor, tp_size=2, pp_size=1)
finally:
executor.shutdown()
@pytest.mark.parametrize(
"bundle_indices, expected_error, create_placement_group",
[
("1,1", "cannot have duplicate values,", 4),
("0,1,2", "must have the same size", 4),
],
indirect=["create_placement_group"],
)
def test_ray_v2_invalid_bundle_indices(
bundle_indices, expected_error, create_placement_group, monkeypatch
):
"""Validate invalid bundle indices are rejected."""
monkeypatch.setenv("VLLM_RAY_BUNDLE_INDICES", bundle_indices)
vllm_config = create_vllm_config(
tensor_parallel_size=2, placement_group=create_placement_group
)
with pytest.raises(AssertionError, match=expected_error):
RayExecutorV2(vllm_config=vllm_config)
@pytest.mark.parametrize("tp_size, pp_size", [(2, 1), (2, 2)])
def test_ray_v2_single_node_generation_with_pg(tp_size, pp_size):
"""E2E LLM generation with a user-provided placement group."""
ensure_ray_initialized()
bundles = [{"GPU": 1, "CPU": 1} for _ in range(tp_size * pp_size)]
pg = ray.util.placement_group(bundles, strategy="PACK")
ray.get(pg.ready())
try:
with patch.object(ray.util, "get_current_placement_group", return_value=pg):
llm = LLM(
model=MODEL,
tensor_parallel_size=tp_size,
pipeline_parallel_size=pp_size,
distributed_executor_backend="ray",
enforce_eager=True,
max_model_len=256,
gpu_memory_utilization=0.3,
)
prompts = [
"Hello, my name is",
"The capital of France is",
"The future of AI is",
]
outputs = llm.generate(prompts)
assert len(outputs) == len(prompts)
for output in outputs:
assert len(output.outputs) > 0
assert len(output.outputs[0].text) > 0
finally:
llm.llm_engine.model_executor.shutdown()
del llm
gc.collect()

View File

@@ -0,0 +1,209 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Orchestration-level integration tests for RayExecutorV2.
"""
import gc
import os
import pathlib
import pytest
import ray
pytestmark = pytest.mark.usefixtures("enable_ray_v2_backend")
MODEL = "facebook/opt-125m"
def _get_env_var(worker, name):
return os.environ.get(name)
def _ray_init():
"""Start Ray with the project root on workers' PYTHONPATH.
Without this, workers cannot unpickle actor classes defined in the
``tests`` package, causing FunctionActorManager to fall back to
TemporaryActor which drops async method signatures."""
project_root = str(pathlib.Path(__file__).resolve().parents[2])
ray.init(
ignore_reinit_error=True,
runtime_env={"env_vars": {"PYTHONPATH": project_root}},
)
@pytest.fixture
def ray_init():
_ray_init()
class _AsyncLLMActor:
def start(self, pg, bundle_indices=None, ray_runtime_env=None):
os.environ["VLLM_USE_RAY_V2_EXECUTOR_BACKEND"] = "1"
# Needed so collective_rpc can pickle _get_env_var over the
# AsyncLLM -> EngineCore ZMQ boundary.
os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1"
if bundle_indices is not None:
os.environ["VLLM_RAY_BUNDLE_INDICES"] = bundle_indices
else:
os.environ.pop("VLLM_RAY_BUNDLE_INDICES", None)
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.executor.abstract import Executor
engine_args = AsyncEngineArgs(
model=MODEL,
tensor_parallel_size=2,
distributed_executor_backend="ray",
enforce_eager=True,
max_model_len=256,
gpu_memory_utilization=0.8,
)
vllm_config = engine_args.create_engine_config()
vllm_config.parallel_config.placement_group = pg
if ray_runtime_env is not None:
vllm_config.parallel_config.ray_runtime_env = ray_runtime_env
executor_class = Executor.get_class(vllm_config)
self.engine = AsyncLLM(
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=False,
log_requests=False,
)
async def generate(self, prompt):
from vllm.sampling_params import SamplingParams
params = SamplingParams(max_tokens=16)
result = None
async for output in self.engine.generate(
prompt, params, request_id="test_request_id"
):
result = output
assert result is not None
return result.outputs[0].text
async def generate_and_get_worker_envs(self, prompt, env_names):
from vllm.sampling_params import SamplingParams
params = SamplingParams(max_tokens=16)
result = None
async for output in self.engine.generate(
prompt, params, request_id="test_request_id"
):
result = output
assert result is not None
text = result.outputs[0].text
env_results = {}
for name in env_names:
vals = await self.engine.collective_rpc(
_get_env_var, timeout=10, args=(name,)
)
env_results[name] = vals
return text, env_results
def shutdown(self):
if engine := getattr(self, "engine", None):
engine.shutdown()
del self.engine
gc.collect()
AsyncLLMActor = ray.remote(num_cpus=0, max_concurrency=1)(_AsyncLLMActor)
def test_multi_replicas(ray_init):
pg1 = ray.util.placement_group([{"GPU": 1, "CPU": 1}] * 2, strategy="PACK")
pg2 = ray.util.placement_group([{"GPU": 1, "CPU": 1}] * 2, strategy="PACK")
ray.get([pg1.ready(), pg2.ready()])
actor1 = AsyncLLMActor.remote()
actor2 = AsyncLLMActor.remote()
ray.get(actor1.start.remote(pg1))
ray.get(actor2.start.remote(pg2))
out1, out2 = ray.get(
[
actor1.generate.remote("Hello world"),
actor2.generate.remote("Hello world"),
]
)
assert len(out1) > 0
assert len(out2) > 0
def test_multi_replicas_with_bundle_indices(ray_init):
pg = ray.util.placement_group([{"GPU": 1, "CPU": 1}] * 4, strategy="PACK")
ray.get(pg.ready())
actor1 = AsyncLLMActor.remote()
actor2 = AsyncLLMActor.remote()
ray.get(actor1.start.remote(pg, bundle_indices="2,1"))
ray.get(actor2.start.remote(pg, bundle_indices="0,3"))
out1, out2 = ray.get(
[
actor1.generate.remote("Hello world"),
actor2.generate.remote("Hello world"),
]
)
assert len(out1) > 0
assert len(out2) > 0
def test_env_var_and_runtime_env_propagation():
"""
Verify env vars (NCCL_, HF_) and parallel_config.ray_runtime_env
propagate to RayWorkerProc actors.
"""
sentinel_vars = {
"NCCL_DEBUG": "INFO",
"HF_TOKEN": "test_sentinel_token",
}
for k, v in sentinel_vars.items():
os.environ[k] = v
try:
# Called directly (not via the ray_init fixture) because sentinel
# env vars must be in os.environ before ray.init() so that Ray
# worker processes inherit them.
_ray_init()
pg = ray.util.placement_group([{"GPU": 1, "CPU": 1}] * 2, strategy="PACK")
ray.get(pg.ready())
# Include the project root so that RayWorkerProc actors can
# unpickle _get_env_var.
project_root = str(pathlib.Path(__file__).resolve().parents[2])
ray_runtime_env = {
"env_vars": {
"RAY_RUNTIME_ENV_TEST": "ray_runtime_env",
"PYTHONPATH": project_root,
},
}
actor = AsyncLLMActor.remote()
ray.get(actor.start.remote(pg, ray_runtime_env=ray_runtime_env))
all_env_names = list(sentinel_vars) + ["RAY_RUNTIME_ENV_TEST"]
text, env_results = ray.get(
actor.generate_and_get_worker_envs.remote("Hello world", all_env_names)
)
assert len(text) > 0
for name, expected in sentinel_vars.items():
for val in env_results[name]:
assert val == expected
for val in env_results["RAY_RUNTIME_ENV_TEST"]:
assert val == "ray_runtime_env"
finally:
for k in sentinel_vars:
os.environ.pop(k, None)

View File

@@ -523,3 +523,20 @@ def test_human_readable_model_len():
for invalid in ["1a", "pwd", "10.24", "1.23M", "1.22T"]:
with pytest.raises(ArgumentError):
parser.parse_args(["--max-model-len", invalid])
def test_ir_op_priority():
from vllm.config.kernel import IrOpPriorityConfig, KernelConfig
ir_op_priority = IrOpPriorityConfig(rms_norm=["vllm_c"])
cfg1 = EngineArgs(ir_op_priority=ir_op_priority).create_engine_config()
cfg2 = EngineArgs(
kernel_config=KernelConfig(ir_op_priority=ir_op_priority)
).create_engine_config()
assert cfg1.kernel_config.ir_op_priority == cfg2.kernel_config.ir_op_priority
with pytest.raises(ValueError, match="rms_norm"):
_ = EngineArgs(
ir_op_priority=ir_op_priority,
kernel_config=KernelConfig(ir_op_priority=ir_op_priority),
).create_engine_config()

View File

@@ -0,0 +1,2 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

View File

@@ -0,0 +1,325 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the Generative Scoring API.
Tests cover:
1. Protocol models (request/response construction)
2. Probability computation (softmax normalization)
3. Input validation
4. Score formula: P(token[0]) / (P(token[0]) + P(token[1]))
5. Prompt building and item ordering
"""
import math
from dataclasses import dataclass, field
from typing import Any
from unittest.mock import MagicMock
import pytest
from vllm.config.multimodal import MultiModalConfig
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.generative_scoring.serving import (
GenerativeScoringItemResult,
GenerativeScoringRequest,
GenerativeScoringResponse,
OpenAIServingGenerativeScoring,
)
from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.logprobs import Logprob
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.tokenizers import get_tokenizer
from vllm.v1.engine.async_llm import AsyncLLM
MODEL_NAME = "Qwen/Qwen3-0.6B"
BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
@dataclass
class MockHFConfig:
model_type: str = "any"
@dataclass
class MockModelConfig:
task = "generate"
runner_type = "generate"
tokenizer = MODEL_NAME
trust_remote_code = False
tokenizer_mode = "auto"
max_model_len = 100
tokenizer_revision = None
multimodal_config = MultiModalConfig()
hf_config = MockHFConfig()
logits_processor_pattern = None
logits_processors: list[str] | None = None
diff_sampling_param: dict | None = None
allowed_local_media_path: str = ""
allowed_media_domains: list[str] | None = None
encoder_config = None
generation_config: str = "auto"
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
skip_tokenizer_init = False
vocab_size = 151936
def get_diff_sampling_param(self):
return self.diff_sampling_param or {}
def get_vocab_size(self):
return self.vocab_size
def _create_mock_engine():
"""Create a mock AsyncLLM engine."""
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
# renderer is accessed by OpenAIServing.__init__ and serving.py
mock_renderer = MagicMock()
mock_renderer.tokenizer = get_tokenizer(MODEL_NAME)
mock_engine.renderer = mock_renderer
return mock_engine
def _create_serving(mock_engine) -> OpenAIServingGenerativeScoring:
"""Create an OpenAIServingGenerativeScoring instance with mocks."""
models = OpenAIServingModels(
engine_client=mock_engine,
base_model_paths=BASE_MODEL_PATHS,
)
return OpenAIServingGenerativeScoring(mock_engine, models, request_logger=None)
def _create_mock_request_output(logprobs_dict: dict[int, float]) -> RequestOutput:
"""Create a mock RequestOutput with specified logprobs."""
logprobs_with_objs = {
tid: Logprob(logprob=lp, rank=i + 1)
for i, (tid, lp) in enumerate(logprobs_dict.items())
}
completion_output = CompletionOutput(
index=0,
text="",
token_ids=[100],
cumulative_logprob=-1.0,
logprobs=[logprobs_with_objs],
finish_reason="length",
)
return RequestOutput(
request_id="test-request",
prompt="test prompt",
prompt_token_ids=[1, 2, 3],
prompt_logprobs=None,
outputs=[completion_output],
finished=True,
)
class TestProtocolModels:
"""Tests for GenerativeScoringRequest and GenerativeScoringResponse."""
def test_request_and_response_all_fields(self):
"""Test request construction with all field types and response structure."""
# Test request with string inputs
req_str = GenerativeScoringRequest(
query="Is this the capital?",
items=["Paris", "London"],
label_token_ids=[9454, 2753],
)
assert req_str.query == "Is this the capital?"
assert req_str.items == ["Paris", "London"]
assert req_str.label_token_ids == [9454, 2753]
assert req_str.apply_softmax is True # default
assert req_str.item_first is False # default
assert req_str.add_special_tokens is True # default
# Test request with pre-tokenized inputs and custom options
req_tok = GenerativeScoringRequest(
query=[100, 200, 300],
items=[[400, 500], [600, 700]],
label_token_ids=[1234, 5678],
apply_softmax=False,
item_first=True,
add_special_tokens=False,
)
assert req_tok.query == [100, 200, 300]
assert req_tok.items == [[400, 500], [600, 700]]
assert req_tok.apply_softmax is False
assert req_tok.item_first is True
assert req_tok.add_special_tokens is False
# Test response structure
response = GenerativeScoringResponse(
model="test-model",
data=[
GenerativeScoringItemResult(index=0, score=0.7),
GenerativeScoringItemResult(index=1, score=0.4),
],
usage={"prompt_tokens": 10, "total_tokens": 12, "completion_tokens": 2},
)
assert response.object == "list"
assert response.model == "test-model"
assert len(response.data) == 2
assert response.data[0].score == 0.7
assert response.data[0].object == "score"
assert response.data[1].score == 0.4
assert response.usage.prompt_tokens == 10
class TestProbabilityComputation:
"""Tests for _compute_probabilities with both softmax modes."""
@pytest.mark.parametrize(
"label_logprobs,apply_softmax,should_sum_to_one",
[
({100: -1.0, 200: -2.0}, True, True),
({100: -100.0, 200: -100.5}, True, True), # numerical stability
({100: -1.0, 200: -2.0}, False, False),
],
ids=["softmax_basic", "softmax_extreme_values", "true_probs"],
)
def test_compute_probabilities(
self, label_logprobs, apply_softmax, should_sum_to_one
):
"""Test probability computation for softmax and true probability modes."""
serving = OpenAIServingGenerativeScoring.__new__(OpenAIServingGenerativeScoring)
probs = serving._compute_probabilities(
label_logprobs, apply_softmax=apply_softmax
)
# Verify sum behavior
total = sum(probs.values())
if should_sum_to_one:
assert abs(total - 1.0) < 1e-6
else:
assert total < 1.0
# Verify math
if apply_softmax:
max_lp = max(label_logprobs.values())
exp_vals = {k: math.exp(v - max_lp) for k, v in label_logprobs.items()}
sum_exp = sum(exp_vals.values())
for tid, lp in label_logprobs.items():
assert abs(probs[tid] - exp_vals[tid] / sum_exp) < 1e-9
else:
for tid, lp in label_logprobs.items():
assert abs(probs[tid] - math.exp(lp)) < 1e-9
def test_score_formula(self):
"""Test the score formula: P(token[0]) / (P(token[0]) + P(token[1]))."""
serving = OpenAIServingGenerativeScoring.__new__(OpenAIServingGenerativeScoring)
# With logprobs -0.5 and -2.0, softmax gives higher prob to first token
logprobs = {9454: -0.5, 2753: -2.0}
probs = serving._compute_probabilities(logprobs, apply_softmax=True)
# Score = P(9454) / (P(9454) + P(2753)) = P(9454) since they sum to 1
score = probs[9454]
# Manual calculation
exp_0 = math.exp(-0.5)
exp_1 = math.exp(-2.0)
expected_score = exp_0 / (exp_0 + exp_1)
assert abs(score - expected_score) < 1e-9
assert score > 0.5 # First token has higher logprob, so higher probability
class TestValidation:
"""Tests for input validation errors."""
@pytest.mark.asyncio
@pytest.mark.parametrize(
"request_kwargs,expected_error",
[
(
{"query": "q", "items": ["i"], "label_token_ids": [999999, 999998]},
"out of vocabulary",
),
(
{"query": "q", "items": [], "label_token_ids": [100, 200]},
"at least one item",
),
],
ids=["invalid_token_id", "empty_items"],
)
async def test_validation_errors(self, request_kwargs, expected_error):
"""Test that invalid inputs return appropriate errors."""
mock_engine = _create_mock_engine()
serving = _create_serving(mock_engine)
request = GenerativeScoringRequest(model=MODEL_NAME, **request_kwargs)
result = await serving.create_generative_scoring(request, None)
assert isinstance(result, ErrorResponse)
assert expected_error in result.error.message.lower()
class TestPromptBuilding:
"""Tests for prompt construction and item ordering."""
@pytest.mark.asyncio
@pytest.mark.parametrize(
"item_first,expected",
[
(False, [[100, 101, 200, 201], [100, 101, 300, 301]]), # query + item
(True, [[200, 201, 100, 101], [300, 301, 100, 101]]), # item + query
],
ids=["query_first", "item_first"],
)
async def test_item_ordering(self, item_first, expected):
"""Test that item_first flag controls prompt concatenation order."""
mock_engine = _create_mock_engine()
serving = _create_serving(mock_engine)
request = GenerativeScoringRequest(
query=[100, 101],
items=[[200, 201], [300, 301]],
label_token_ids=[500, 501],
item_first=item_first,
)
engine_inputs, _ = await serving._build_prompts(
request, MagicMock(), max_model_len=4096
)
for i, exp in enumerate(expected):
assert engine_inputs[i]["prompt_token_ids"] == exp
class TestGeneration:
"""Tests for the full generation flow with mocked engine."""
@pytest.mark.asyncio
async def test_successful_generation(self):
"""Test successful score generation returns valid response."""
mock_engine = _create_mock_engine()
serving = _create_serving(mock_engine)
mock_logprobs = {1234: -0.5, 5678: -2.0, 100: -3.0}
mock_output = _create_mock_request_output(mock_logprobs)
async def mock_generate(*args, **kwargs):
yield mock_output
mock_engine.generate = mock_generate
request = GenerativeScoringRequest(
model=MODEL_NAME,
query="Is Paris the capital?",
items=["Yes", "No"],
label_token_ids=[1234, 5678],
)
result = await serving.create_generative_scoring(request, None)
assert isinstance(result, GenerativeScoringResponse)
assert len(result.data) == 2
for item_result in result.data:
assert 0.0 <= item_result.score <= 1.0
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,157 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""End-to-end tests for the Generative Scoring API.
Tests verify the full HTTP request/response flow using RemoteOpenAIServer.
"""
import pytest
import requests
from ....utils import RemoteOpenAIServer
MODEL_NAME = "Qwen/Qwen3-0.6B"
@pytest.fixture(scope="module")
def server():
args = [
"--dtype",
"bfloat16",
"--max-model-len",
"512",
"--enforce-eager",
"--max-num-seqs",
"32",
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
class TestGenerativeScoringAPI:
"""End-to-end tests for the Generative Scoring API."""
@pytest.mark.asyncio
async def test_basic_score_and_response_structure(self, server: RemoteOpenAIServer):
"""Test basic generative scoring request and verify response structure."""
response = requests.post(
server.url_for("generative_scoring"),
json={
"model": MODEL_NAME,
"query": "Is Paris the capital of France? Answer Yes or No: ",
"items": ["Paris is beautiful.", "London is rainy."],
"label_token_ids": [9454, 2753],
},
)
assert response.status_code == 200, f"Response: {response.text}"
data = response.json()
# Verify response structure
assert data["id"].startswith("generative-scoring-")
assert data["object"] == "list"
assert "model" in data
assert "usage" in data
assert len(data["data"]) == 2
# Verify each result
for i, result in enumerate(data["data"]):
assert result["index"] == i
assert result["object"] == "score"
assert 0.0 <= result["score"] <= 1.0
# Verify usage tracking
usage = data["usage"]
assert usage["prompt_tokens"] > 0
assert usage["completion_tokens"] > 0
assert (
usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"]
)
@pytest.mark.asyncio
async def test_multiple_items(self, server: RemoteOpenAIServer):
"""Test generative scoring request with multiple items."""
response = requests.post(
server.url_for("generative_scoring"),
json={
"model": MODEL_NAME,
"query": "Is this city a capital? ",
"items": ["Paris", "London", "Berlin", "New York", "Tokyo"],
"label_token_ids": [9454, 2753],
},
)
assert response.status_code == 200
data = response.json()
assert len(data["data"]) == 5
@pytest.mark.asyncio
async def test_validation_missing_label_token_ids(self, server: RemoteOpenAIServer):
"""Test that missing label_token_ids returns a validation error."""
response = requests.post(
server.url_for("generative_scoring"),
json={
"model": MODEL_NAME,
"query": "Test query",
"items": ["item1", "item2"],
},
)
# Missing required field returns 400 (manual JSON parsing)
assert response.status_code == 400
@pytest.mark.asyncio
async def test_validation_empty_items(self, server: RemoteOpenAIServer):
"""Test that empty items returns an error."""
response = requests.post(
server.url_for("generative_scoring"),
json={
"model": MODEL_NAME,
"query": "Test query",
"items": [],
"label_token_ids": [100, 200],
},
)
assert response.status_code == 400
@pytest.mark.asyncio
@pytest.mark.parametrize(
"label_token_ids,expected_status",
[
([9999999999, 9999999998], 400), # Out of vocab range
],
ids=["invalid_token_ids"],
)
async def test_validation_errors(
self, server: RemoteOpenAIServer, label_token_ids, expected_status
):
"""Test validation errors for various invalid inputs."""
response = requests.post(
server.url_for("generative_scoring"),
json={
"model": MODEL_NAME,
"query": "Test query",
"items": ["item1"],
"label_token_ids": label_token_ids,
},
)
assert response.status_code == expected_status
@pytest.mark.asyncio
async def test_score_consistency(self, server: RemoteOpenAIServer):
"""Test that scores are deterministic across identical requests."""
request_body = {
"model": MODEL_NAME,
"query": "Is this consistent? ",
"items": ["Yes it is."],
"label_token_ids": [100, 200],
}
r1 = requests.post(server.url_for("generative_scoring"), json=request_body)
r2 = requests.post(server.url_for("generative_scoring"), json=request_body)
assert r1.status_code == 200 and r2.status_code == 200
r1_score = r1.json()["data"][0]["score"]
r2_score = r2.json()["data"][0]["score"]
assert abs(r1_score - r2_score) < 1e-6
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -234,7 +234,7 @@ async def test_score_api_queries_str_documents_image_url_plus_text_content(
assert score.id is not None
assert score.data is not None
assert len(score.data) == 1
assert score.usage.prompt_tokens == 108
assert score.usage.prompt_tokens == 107
assert_score(
score.data[0].score, TEXT_VS_TEXT_PLUS_IMAGE, backend, "text_vs_text_plus_image"
)
@@ -264,7 +264,7 @@ async def test_score_api_queries_str_documents_list(
assert score.id is not None
assert score.data is not None
assert len(score.data) == 4
assert score.usage.prompt_tokens == 368
assert score.usage.prompt_tokens == 367
assert_score(score.data[0].score, TEXT_VS_TEXT, backend, "list[0]_text_vs_text")
assert_score(score.data[1].score, TEXT_VS_TEXT, backend, "list[1]_text_vs_text")
assert_score(score.data[2].score, TEXT_VS_IMAGE, backend, "list[2]_text_vs_image")
@@ -353,7 +353,7 @@ async def test_score_api_queries_list_documents_list(
assert score.id is not None
assert score.data is not None
assert len(score.data) == 4
assert score.usage.prompt_tokens == 368
assert score.usage.prompt_tokens == 367
assert_score(score.data[0].score, TEXT_VS_TEXT, backend, "paired[0]_text_vs_text")
assert_score(score.data[1].score, TEXT_VS_TEXT, backend, "paired[1]_text_vs_text")
assert_score(score.data[2].score, TEXT_VS_IMAGE, backend, "paired[2]_text_vs_image")

View File

@@ -26,13 +26,18 @@ TEXTS_2 = [
]
@pytest.fixture(scope="module")
def server():
@pytest.fixture(scope="module", params=[True, False])
def server(request):
args = [
"--max-model-len",
str(MAX_MODEL_LEN),
]
# Test run pooling score MaxSim on worker side (GPU)
# aka flash-late-interaction
if not request.param:
args += ["--no-enable-flash-late-interaction"]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server

View File

@@ -0,0 +1,85 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Regression test: calling ``/tokenize`` with multimodal data followed by
``/v1/chat/completions`` with the same data must not cause an error.
Ensures that the ``/tokenize`` endpoint does not pollute internal caches
(e.g. multimodal feature caches) and that a subsequent
``/v1/chat/completions`` request with the same multimodal payload
completes successfully.
"""
import json
import openai
import pytest
import pytest_asyncio
import requests
from tests.utils import RemoteOpenAIServer
MODEL_NAME = "Qwen/Qwen2.5-VL-3B-Instruct"
@pytest.fixture(scope="module")
def server():
args = [
"--dtype",
"bfloat16",
"--max-model-len",
"4096",
"--max-num-seqs",
"5",
"--enforce-eager",
"--limit-mm-per-prompt",
json.dumps({"image": 1}),
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest_asyncio.fixture
async def client(server):
async with server.get_async_client() as async_client:
yield async_client
@pytest.mark.asyncio
async def test_tokenize_then_chat_completion_with_image(
client: openai.AsyncOpenAI,
server: RemoteOpenAIServer,
local_asset_server,
):
"""Tokenize a multimodal message, then send the same message to chat
completions. The chat completion must succeed (not 500)."""
image_url = local_asset_server.url_for("stop_sign.jpg")
messages = [
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": image_url}},
{"type": "text", "text": "Describe this image briefly."},
],
}
]
tok_resp = requests.post(
server.url_for("tokenize"),
json={"model": MODEL_NAME, "messages": messages},
)
tok_resp.raise_for_status()
tok_data = tok_resp.json()
assert tok_data["count"] > 0, "Tokenization must return tokens"
chat_completion = await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=10,
temperature=0.0,
)
assert chat_completion.choices[0].message.content, (
"Chat completion must produce non-empty content after tokenize"
)

View File

@@ -0,0 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
model_name: amd/gpt-oss-20b-w-mxfp4-a-bf16
metric_threshold: 0.568
reasoning_effort: low
server_args: "--attention-backend ROCM_AITER_UNIFIED_ATTN --moe-backend aiter"
env:
VLLM_ROCM_USE_AITER: "1"

View File

@@ -0,0 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
model_name: amd/gpt-oss-20b-w-mxfp4-a-bf16
metric_threshold: 0.568
reasoning_effort: low
server_args: "--attention-backend ROCM_AITER_UNIFIED_ATTN --moe-backend triton"

View File

@@ -0,0 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
model_name: amd/gpt-oss-20b-MoE-Quant-W-MXFP4-A-FP8-KV-FP8
metric_threshold: 0.568
reasoning_effort: low
server_args: "--attention-backend ROCM_AITER_UNIFIED_ATTN"
env:
VLLM_ROCM_USE_AITER: "1"

View File

@@ -1,3 +1,6 @@
# GFX950 model configurations for GPQA evaluation
# Tests different environment variable combinations
gpt-oss-20b-rocm-baseline.yaml
gpt-oss-20b-rocm-baseline.yaml
gpt-oss-20b-rocm-quark-mxfp4-bf16-aiter.yaml
gpt-oss-20b-rocm-quark-mxfp4-bf16-triton.yaml
gpt-oss-20b-rocm-quark-mxfp4-fp8-triton.yaml

Some files were not shown because too many files have changed in this diff Show More