Compare commits

..

18 Commits

Author SHA1 Message Date
013b73e9b2 Fix managed KV cache: use __cuda_array_interface__ instead of UntypedStorage.from_blob
UntypedStorage.from_blob was removed in PyTorch 2.11+. Use the
standard __cuda_array_interface__ protocol to wrap cudaMallocManaged
pointers into PyTorch tensors — this works across all PyTorch versions.

Also removed cudaMemAdvise calls — ctypes struct passing for
cudaMemLocation is broken on ARM64 (returns EINVAL). The advise hints
are optional; pages will page-fault to GPU on-demand regardless.

CPU memset (ctypes.memset) is still used instead of cudaMemset to
avoid forcing all pages into HBM during zeroing.
2026-04-12 06:56:52 +00:00
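
For reference, a minimal sketch of the wrapping approach this commit describes: allocate with cudaMallocManaged through ctypes, zero the pages on the CPU, and hand the pointer to PyTorch via the __cuda_array_interface__ protocol. Names and the float16 layout are illustrative assumptions, not the actual vLLM code; torch.as_tensor is assumed to consume the protocol, which current PyTorch releases do.

import ctypes
import math

import torch

# May need the versioned name on some systems, e.g. "libcudart.so.12".
cudart = ctypes.CDLL("libcudart.so")

def managed_fp16_empty(shape):
    """Allocate a cudaMallocManaged float16 buffer and wrap it as a torch tensor."""
    nbytes = 2 * math.prod(shape)  # float16 = 2 bytes per element
    ptr = ctypes.c_void_p()
    # cudaMallocManaged(&ptr, nbytes, cudaMemAttachGlobal)
    rc = cudart.cudaMallocManaged(ctypes.byref(ptr), ctypes.c_size_t(nbytes),
                                  ctypes.c_uint(1))
    assert rc == 0, f"cudaMallocManaged failed with error {rc}"

    # Zero via CPU memset so the pages stay in LPDDR/EGM instead of being
    # forced into HBM by a device-side cudaMemset.
    ctypes.memset(ptr, 0, nbytes)

    class _ManagedArray:
        # Standard protocol that PyTorch can consume regardless of version.
        __cuda_array_interface__ = {
            "shape": tuple(shape),
            "typestr": "<f2",            # little-endian float16
            "data": (ptr.value, False),  # (device pointer, read-only flag)
            "strides": None,             # C-contiguous
            "version": 3,
        }

    return torch.as_tensor(_ManagedArray(), device="cuda")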
c77342da87 KV cache: prefer CPU placement, zero via CPU not GPU
Two critical fixes for managed memory KV cache allocation:

1. Preferred location set to CPU (not GPU). The KV cache is too large
   for HBM (50-100+ GiB). Setting preferred location to GPU causes the
   driver to try migrating the entire allocation to HBM → OOM. With
   CPU as preferred location, pages stay in LPDDR/EGM and page-fault
   to GPU on-demand during attention ops.

2. Zero memory via CPU memset (not cudaMemset). cudaMemset runs on the
   device, forcing ALL pages to migrate to GPU before zeroing — exactly
   what we're trying to avoid. CPU memset keeps pages in LPDDR.

Also added SetAccessedBy(GPU) so the GPU can access pages remotely
over C2C NVLink without triggering page migration back to GPU.
2026-04-12 03:44:16 +00:00
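
As an illustration of the placement policy described above, a sketch using the legacy int-device cudaMemAdvise signature (enum values are the documented CUDA runtime values; function and variable names are assumptions, and the newer commit above ends up dropping these hints entirely on ARM64):

import ctypes

cudart = ctypes.CDLL("libcudart.so")  # or the versioned "libcudart.so.12"

CUDA_CPU_DEVICE_ID = -1                     # cudaCpuDeviceId
CUDA_MEM_ADVISE_SET_PREFERRED_LOCATION = 3  # cudaMemAdviseSetPreferredLocation
CUDA_MEM_ADVISE_SET_ACCESSED_BY = 5         # cudaMemAdviseSetAccessedBy

def advise_kv_cache(ptr, nbytes, gpu_id=0):
    """Keep managed KV cache pages resident in CPU memory, GPU-accessible over C2C."""
    # 1. Prefer CPU residency so the driver never tries to migrate the whole
    #    50-100+ GiB allocation into HBM.
    cudart.cudaMemAdvise(ctypes.c_void_p(ptr), ctypes.c_size_t(nbytes),
                         CUDA_MEM_ADVISE_SET_PREFERRED_LOCATION,
                         CUDA_CPU_DEVICE_ID)
    # 2. Let the GPU access the pages remotely over C2C NVLink instead of
    #    pulling them back into HBM on every touch.
    cudart.cudaMemAdvise(ctypes.c_void_p(ptr), ctypes.c_size_t(nbytes),
                         CUDA_MEM_ADVISE_SET_ACCESSED_BY, gpu_id)
    # 3. Zero on the CPU (not cudaMemset), so zeroing does not force every
    #    page into HBM.
    ctypes.memset(ptr, 0, nbytes)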
7f35bc4158 Targeted KV cache managed memory allocation
Instead of swapping the global CUDA allocator (which broke cuBLAS),
allocate KV cache via cudaMallocManaged directly in
_allocate_kv_cache_tensors(). Controlled by
VLLM_KV_CACHE_USE_MANAGED_MEMORY env var.

Model weights and compute intermediates stay in HBM via default
cudaMalloc. Only KV cache spills into EGM/LPDDR.
2026-04-11 02:14:34 +00:00
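
A hedged sketch of the resulting split (helper name and the truthy check are illustrative and reuse the managed_fp16_empty sketch above; the real change lives in _allocate_kv_cache_tensors()):

import os

import torch

def allocate_kv_cache_tensor(shape):
    """KV cache goes to managed memory when the env var is set; everything else stays in HBM."""
    if os.environ.get("VLLM_KV_CACHE_USE_MANAGED_MEMORY", "0") == "1":
        # Only the KV cache spills into EGM/LPDDR via cudaMallocManaged.
        return managed_fp16_empty(shape)
    # Default path: ordinary cudaMalloc-backed HBM allocation, same as model
    # weights and compute intermediates.
    return torch.zeros(shape, dtype=torch.float16, device="cuda")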
487dd34e04 Selective prefetch: only prefetch allocations <2 GiB to GPU
Model weights (small tensors) must be in HBM for cuBLAS GEMM ops,
which can't page-fault into managed memory. KV cache blocks are
large and numerous — prefetching them all fills HBM and causes
OOM. The 2 GiB threshold separates compute data from cache data.
2026-04-10 14:58:57 +00:00
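
A minimal sketch of the size heuristic described here (threshold handling and names are illustrative; cudart is the same ctypes handle as in the sketches above):

import ctypes

cudart = ctypes.CDLL("libcudart.so")

PREFETCH_THRESHOLD_BYTES = 2 * 1024 ** 3  # 2 GiB

def maybe_prefetch_to_gpu(ptr, nbytes, gpu_id=0):
    """Prefetch small allocations (weights) into HBM; let large ones (KV cache)
    page-fault to the GPU on demand."""
    if nbytes < PREFETCH_THRESHOLD_BYTES:
        # cudaMemPrefetchAsync(devPtr, count, dstDevice, stream=NULL)
        cudart.cudaMemPrefetchAsync(ctypes.c_void_p(ptr),
                                    ctypes.c_size_t(nbytes),
                                    ctypes.c_int(gpu_id),
                                    None)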
a15f86ecfa Remove cudaMemPrefetchAsync from managed allocator
Eager prefetching was filling HBM+EGM, causing subsequent
cudaMallocManaged calls to fail after model loading. On GH200
with EGM, pages should migrate on-demand via hardware page faults
over C2C NVLink. The cudaMemAdviseSetPreferredLocation(GPU) hint
is sufficient to prefer GPU placement with LPDDR fallback.
2026-04-10 05:58:11 +00:00
Michael
2a69949bda [Bugfix]: Fix Gemma4ToolParser.__init__() missing tools parameter (#38847)
Signed-off-by: Michael Hospedales <hospedales@me.com>
(cherry picked from commit bb39382b2b)
2026-04-02 16:45:38 -07:00
Luciano Martins
8adcf8c40a feat(models): implement Google Gemma 4 architecture support (MoE, Multimodal, Reasoning, Tool-Use) (#38826)
Signed-off-by: Luciano Martins <lucianommartins@users.noreply.github.com>
Signed-off-by: Luciano Martins <lucianomartins@google.com>
Co-authored-by: Luciano Martins <lucianommartins@users.noreply.github.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
(cherry picked from commit 08ed2b9688)
2026-04-02 11:49:53 -07:00
khluu
cfad6a509c Revert "[Bugfix] Restrict TRTLLM attention to SM100, fixing GB300 (SM103) hang (#38730)"
This reverts commit c284a6671c.
2026-04-01 15:14:58 -07:00
Stefano Castagnetta
c284a6671c [Bugfix] Restrict TRTLLM attention to SM100, fixing GB300 (SM103) hang (#38730)
Signed-off-by: Stefano Castagnetta <scastagnetta@nvidia.com>
(cherry picked from commit 6183cae1bd)
2026-04-01 12:11:03 -07:00
Chauncey
3a30a1a6a8 [Misc] Rename think_start_str/think_end_str to reasoning_start_str/reasoning_end_str (#38242)
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
(cherry picked from commit cbe7d18096)
2026-04-01 12:10:53 -07:00
Juan Pérez de Algaba
29982d48b3 (security) Enforce frame limit in VideoMediaIO (#38636)
Signed-off-by: jperezde <jperezde@redhat.com>
(cherry picked from commit 58ee614221)
2026-04-01 12:10:40 -07:00
Yifan Qiao
1dbbafd3f3 [Feat][v1] Simple yet General CPU KV Cache Offloading (#37160)
Signed-off-by: Yifan Qiao <yifanqiao@berkeley.edu>
Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
(cherry picked from commit 91e4521f9f)
2026-04-01 01:03:14 -07:00
Lucas Wilkinson
0ee3b7fc3d [Bugfix][MLA] Add logits size budget to sparse indexer prefill chunking (#36178)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
(cherry picked from commit eb47454987)
2026-04-01 01:02:58 -07:00
Matthew Bonanni
268bed9cf3 [Bugfix][Async] Fix async spec decoding with hybrid models (#38556)
Signed-off-by: SandishKumarHN <sandishkumarhn@gmail.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: SandishKumarHN <sandishkumarhn@gmail.com>
(cherry picked from commit 757068dc65)
2026-04-01 01:02:35 -07:00
Jiangyun Zhu
bcc0fdd0f3 [CI] fix LM Eval Qwen3.5 Models (B200) (#38632)
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
(cherry picked from commit ea7bfde6e4)
2026-04-01 01:02:20 -07:00
wang.yuqi
69b8bd4b33 [CI Failure] pin colmodernvbert revision (#38612)
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Signed-off-by: wang.yuqi <noooop@126.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
(cherry picked from commit 719735d6c5)
2026-04-01 01:02:04 -07:00
Li, Jiang
12449f9492 [Bugfix][CPU] Skip set_num_threads after thread binding (#38535)
Signed-off-by: jiang1.li <jiang1.li@intel.com>
(cherry picked from commit 6557f4937f)
2026-03-30 23:01:42 -07:00
haosdent
b92312dfd7 [CI] Fix SPLADE pooler test broken by #38139 (#38495)
Signed-off-by: haosdent <haosdent@gmail.com>
(cherry picked from commit a08b7733fd)
2026-03-30 21:52:13 -07:00
612 changed files with 10021 additions and 29332 deletions

View File

@@ -5,7 +5,6 @@ steps:
     depends_on: []
     device: amd_cpu
     no_plugin: true
-    soft_fail: true
     commands:
       - >
         docker build
@@ -21,3 +20,11 @@
       - docker push "rocm/vllm-ci:${BUILDKITE_COMMIT}"
     env:
       DOCKER_BUILDKIT: "1"
+    retry:
+      automatic:
+        - exit_status: -1 # Agent was lost
+          limit: 1
+        - exit_status: -10 # Agent was lost
+          limit: 1
+        - exit_status: 1 # Machine occasionally fail
+          limit: 1

View File

@@ -13,14 +13,12 @@ steps:
       - tests/kernels/attention/test_cpu_attn.py
       - tests/kernels/moe/test_cpu_fused_moe.py
       - tests/kernels/test_onednn.py
-      - tests/kernels/test_awq_int4_to_int8.py
     commands:
       - |
         bash .buildkite/scripts/hardware_ci/run-cpu-test.sh 20m "
           pytest -x -v -s tests/kernels/attention/test_cpu_attn.py
           pytest -x -v -s tests/kernels/moe/test_cpu_fused_moe.py
-          pytest -x -v -s tests/kernels/test_onednn.py
-          pytest -x -v -s tests/kernels/test_awq_int4_to_int8.py"
+          pytest -x -v -s tests/kernels/test_onednn.py"
   - label: CPU-Compatibility Tests
     depends_on: []

View File

@@ -36,7 +36,6 @@
       "model": "meta-llama/Llama-3.1-8B-Instruct",
       "backend": "vllm",
       "ignore-eos": "",
-      "temperature": 0,
       "num_prompts": 200
     }
   },

View File

@@ -22,7 +22,6 @@
       "hf_split": "test",
       "no_stream": "",
       "no_oversample": "",
-      "temperature": 0,
       "num_prompts": 200
     }
   },

View File

@@ -26,7 +26,6 @@
       "model": "meta-llama/Llama-3.1-8B-Instruct",
       "backend": "vllm",
       "ignore-eos": "",
-      "temperature": 0,
       "num_prompts": 200
     }
   },

View File

@@ -26,7 +26,6 @@
       "model": "meta-llama/Llama-3.1-8B-Instruct",
       "backend": "vllm",
       "ignore-eos": "",
-      "temperature": 0,
       "num_prompts": 200
     }
   },

View File

@@ -21,7 +20,6 @@
       "backend": "vllm",
       "dataset_name": "sharegpt",
       "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
-      "temperature": 0,
       "num_prompts": 200
     }
   },
@@ -48,7 +47,6 @@
       "backend": "vllm",
       "dataset_name": "sharegpt",
       "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
-      "temperature": 0,
       "num_prompts": 200
     }
   },
@@ -75,7 +73,6 @@
       "backend": "vllm",
       "dataset_name": "sharegpt",
       "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
-      "temperature": 0,
       "num_prompts": 200
     }
   },
@@ -103,7 +100,6 @@
       "backend": "vllm",
       "dataset_name": "sharegpt",
       "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
-      "temperature": 0,
       "num_prompts": 200
     }
   },
@@ -131,7 +127,6 @@
       "backend": "vllm",
       "dataset_name": "sharegpt",
       "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
-      "temperature": 0,
       "num_prompts": 200
     }
   },
@@ -156,7 +151,6 @@
       "backend": "vllm",
       "dataset_name": "sharegpt",
       "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
-      "temperature": 0,
       "num_prompts": 200
     }
   }

View File

@@ -13,7 +13,6 @@
       "backend": "vllm",
       "dataset_name": "sharegpt",
       "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
-      "temperature": 0,
       "num_prompts": 200
     }
   },
@@ -31,7 +30,6 @@
       "backend": "vllm",
       "dataset_name": "sharegpt",
       "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
-      "temperature": 0,
       "num_prompts": 200
     }
   },
@@ -49,7 +47,6 @@
       "backend": "vllm",
       "dataset_name": "sharegpt",
       "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
-      "temperature": 0,
       "num_prompts": 200
     }
   },
@@ -70,7 +67,6 @@
       "backend": "vllm",
       "dataset_name": "sharegpt",
       "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
-      "temperature": 0,
       "num_prompts": 200
     }
   }

View File

@@ -239,29 +239,13 @@ fi
# --- Docker housekeeping --- # --- Docker housekeeping ---
cleanup_docker cleanup_docker
aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin "$REGISTRY"
aws ecr get-login-password --region us-east-1 | docker login --username AWS --password-stdin 936637512419.dkr.ecr.us-east-1.amazonaws.com
# --- Build or pull test image --- # --- Build or pull test image ---
IMAGE="${IMAGE_TAG_XPU:-${image_name}}" if [[ -n "${IMAGE_TAG_XPU:-}" ]]; then
echo "Using prebuilt XPU image: ${IMAGE_TAG_XPU}"
echo "Using image: ${IMAGE}" docker pull "${IMAGE_TAG_XPU}"
if docker image inspect "${IMAGE}" >/dev/null 2>&1; then
echo "Image already exists locally, skipping pull"
else else
echo "Image not found locally, waiting for lock..." echo "Using prebuilt XPU image: ${image_name}"
docker pull "${image_name}"
flock /tmp/docker-pull.lock bash -c "
if docker image inspect '${IMAGE}' >/dev/null 2>&1; then
echo 'Image already pulled by another runner'
else
echo 'Pulling image...'
timeout 900 docker pull '${IMAGE}'
fi
"
echo "Pull step completed"
fi fi
remove_docker_container() { remove_docker_container() {

View File

@@ -42,7 +42,6 @@ docker run \
     python3 examples/basic/offline_inference/generate.py --model superjob/Qwen3-4B-Instruct-2507-GPTQ-Int4 --block-size 64 --enforce-eager --max-model-len 8192
     python3 examples/basic/offline_inference/generate.py --model ibm-research/PowerMoE-3b --block-size 64 --enforce-eager -tp 2
     python3 examples/basic/offline_inference/generate.py --model ibm-research/PowerMoE-3b --block-size 64 --enforce-eager -tp 2 --enable-expert-parallel
-    python3 examples/basic/offline_inference/generate.py --model OPEA/Qwen2.5-0.5B-Instruct-int4-sym-inc --block-size 64 --enforce-eager --max-model-len 8192
     cd tests
     pytest -v -s v1/core --ignore=v1/core/test_reset_prefix_cache_e2e.py --ignore=v1/core/test_scheduler_e2e.py
     pytest -v -s v1/engine

View File

@@ -790,7 +790,7 @@ steps:
       - tests/kernels/helion/
       - vllm/platforms/rocm.py
     commands:
-      - pip install helion==0.3.3
+      - pip install helion
       - pytest -v -s kernels/helion/

View File

@@ -2,6 +2,14 @@ group: Benchmarks
 depends_on:
   - image-build
 steps:
+  - label: Benchmarks
+    timeout_in_minutes: 20
+    working_dir: "/vllm-workspace/.buildkite"
+    source_file_dependencies:
+      - benchmarks/
+    commands:
+      - bash scripts/run-benchmarks.sh
   - label: Benchmarks CLI Test
     timeout_in_minutes: 20
     source_file_dependencies:

View File

@@ -72,7 +72,6 @@ steps:
       - vllm/v1/attention/backends/flashinfer.py
       - vllm/compilation/ # TODO(luka) limit to vllm/compilation/passes
       - tests/compile/passes/test_fusion_attn.py
-      - tests/compile/passes/test_mla_attn_quant_fusion.py
       - tests/compile/passes/test_silu_mul_quant_fusion.py
       - tests/compile/passes/distributed/test_fusion_all_reduce.py
       - tests/compile/fullgraph/test_full_graph.py
@@ -80,7 +79,6 @@
       # b200 runners are limited, so we limit the tests to the minimum set only supported on Blackwell
       - nvidia-smi
       - pytest -v -s tests/compile/passes/test_fusion_attn.py -k FLASHINFER
-      - pytest -v -s tests/compile/passes/test_mla_attn_quant_fusion.py
       - pytest -v -s tests/compile/passes/test_silu_mul_quant_fusion.py
       # this runner has 2 GPUs available even though num_devices=2 is not set
       - pytest -v -s tests/compile/passes/distributed/test_fusion_all_reduce.py

View File

@@ -224,20 +224,6 @@ steps:
commands: commands:
- ./.buildkite/scripts/run-multi-node-test.sh /vllm-workspace/tests 2 2 $IMAGE_TAG "VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed' && NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed' && python3 ../examples/offline_inference/data_parallel.py -dp=2 -tp=1 --dp-num-nodes=2 --dp-node-rank=0 --dp-master-addr=192.168.10.10 --dp-master-port=12345 --enforce-eager --trust-remote-code && VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py && VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py" "VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed' && NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed' && python3 ../examples/offline_inference/data_parallel.py -dp=2 -tp=1 --dp-num-nodes=2 --dp-node-rank=1 --dp-master-addr=192.168.10.10 --dp-master-port=12345 --enforce-eager --trust-remote-code" - ./.buildkite/scripts/run-multi-node-test.sh /vllm-workspace/tests 2 2 $IMAGE_TAG "VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed' && NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed' && python3 ../examples/offline_inference/data_parallel.py -dp=2 -tp=1 --dp-num-nodes=2 --dp-node-rank=0 --dp-master-addr=192.168.10.10 --dp-master-port=12345 --enforce-eager --trust-remote-code && VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py && VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py" "VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed' && NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed' && python3 ../examples/offline_inference/data_parallel.py -dp=2 -tp=1 --dp-num-nodes=2 --dp-node-rank=1 --dp-master-addr=192.168.10.10 --dp-master-port=12345 --enforce-eager --trust-remote-code"
- label: MessageQueue TCP Multi-Node (2 GPUs)
timeout_in_minutes: 10
working_dir: "/vllm-workspace/tests"
num_devices: 1
num_nodes: 2
no_plugin: true
optional: true
source_file_dependencies:
- vllm/distributed/device_communicators/shm_broadcast.py
- vllm/distributed/parallel_state.py
- tests/distributed/test_mq_tcp_multinode.py
commands:
- ./.buildkite/scripts/run-multi-node-test.sh /vllm-workspace/tests 2 1 $IMAGE_TAG "torchrun --nnodes 2 --nproc-per-node=1 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_mq_tcp_multinode.py" "torchrun --nnodes 2 --nproc-per-node=1 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_mq_tcp_multinode.py"
- label: Distributed NixlConnector PD accuracy (4 GPUs) - label: Distributed NixlConnector PD accuracy (4 GPUs)
timeout_in_minutes: 30 timeout_in_minutes: 30
working_dir: "/vllm-workspace/tests" working_dir: "/vllm-workspace/tests"
@@ -308,23 +294,3 @@ steps:
commands: commands:
- pytest -v -s distributed/test_pp_cudagraph.py - pytest -v -s distributed/test_pp_cudagraph.py
- pytest -v -s distributed/test_pipeline_parallel.py - pytest -v -s distributed/test_pipeline_parallel.py
- label: RayExecutorV2 (4 GPUs)
timeout_in_minutes: 60
working_dir: "/vllm-workspace/tests"
num_devices: 4
source_file_dependencies:
- vllm/v1/executor/ray_executor_v2.py
- vllm/v1/executor/abstract.py
- vllm/v1/executor/multiproc_executor.py
- tests/distributed/test_ray_v2_executor.py
- tests/distributed/test_ray_v2_executor_e2e.py
- tests/distributed/test_pipeline_parallel.py
- tests/basic_correctness/test_basic_correctness.py
commands:
- export VLLM_USE_RAY_V2_EXECUTOR_BACKEND=1
- export NCCL_CUMEM_HOST_ENABLE=0
- pytest -v -s distributed/test_ray_v2_executor.py
- pytest -v -s distributed/test_ray_v2_executor_e2e.py
- pytest -v -s distributed/test_pipeline_parallel.py -k "ray"
- TARGET_TEST_SUITE=L4 pytest -v -s basic_correctness/test_basic_correctness.py -k "ray"

View File

@@ -13,8 +13,8 @@ steps:
     - pytest -v -s distributed/test_eplb_algo.py
     - pytest -v -s distributed/test_eplb_utils.py
-  - label: EPLB Execution # 17min
-    timeout_in_minutes: 27
+  - label: EPLB Execution
+    timeout_in_minutes: 20
     working_dir: "/vllm-workspace/tests"
     num_devices: 4
     source_file_dependencies:

View File

@@ -2,16 +2,6 @@ group: Kernels
 depends_on:
   - image-build
 steps:
-  - label: vLLM IR Tests
-    timeout_in_minutes: 10
-    working_dir: "/vllm-workspace/"
-    source_file_dependencies:
-      - vllm/ir
-      - vllm/kernels
-    commands:
-      - pytest -v -s tests/ir
-      - pytest -v -s tests/kernels/ir
   - label: Kernels Core Operation Test
     timeout_in_minutes: 75
     source_file_dependencies:
@@ -29,7 +19,6 @@ steps:
       - vllm/v1/attention
       # TODO: remove this dependency (https://github.com/vllm-project/vllm/issues/32267)
       - vllm/model_executor/layers/attention
-      - vllm/utils/flashinfer.py
       - tests/kernels/attention
     commands:
       - pytest -v -s kernels/attention --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
@@ -140,7 +129,7 @@ steps:
       - vllm/utils/import_utils.py
       - tests/kernels/helion/
     commands:
-      - pip install helion==0.3.3
+      - pip install helion
       - pytest -v -s kernels/helion/

View File

@@ -18,6 +18,5 @@ steps:
     # Avoid importing model tests that cause CUDA reinitialization error
     - pytest models/test_transformers.py -v -s -m 'distributed(num_gpus=2)'
     - pytest models/language -v -s -m 'distributed(num_gpus=2)'
-    - pytest models/multimodal/generation/test_phi4siglip.py -v -s -m 'distributed(num_gpus=2)'
-    - pytest models/multimodal -v -s -m 'distributed(num_gpus=2)' --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/generation/test_phi4siglip.py
+    - pytest models/multimodal -v -s -m 'distributed(num_gpus=2)' --ignore models/multimodal/generation/test_whisper.py
     - VLLM_WORKER_MULTIPROC_METHOD=spawn pytest models/multimodal/generation/test_whisper.py -v -s -m 'distributed(num_gpus=2)'

.github/CODEOWNERS vendored
View File

@@ -2,20 +2,16 @@
# for more info about CODEOWNERS file # for more info about CODEOWNERS file
# This lists cover the "core" components of vLLM that require careful review # This lists cover the "core" components of vLLM that require careful review
/vllm/compilation @zou3519 @youkaichao @ProExpertProg @BoyuanFeng @vadiklyutiy /vllm/compilation @zou3519 @youkaichao @ProExpertProg @BoyuanFeng
/vllm/distributed/kv_transfer @NickLucche @ApostaC @orozery /vllm/distributed/kv_transfer @NickLucche @ApostaC @orozery
/vllm/lora @jeejeelee /vllm/lora @jeejeelee
/vllm/model_executor/layers/attention @LucasWilkinson @MatthewBonanni /vllm/model_executor/layers/attention @LucasWilkinson @MatthewBonanni
/vllm/model_executor/layers/fused_moe @mgoin @pavanimajety /vllm/model_executor/layers/fused_moe @mgoin @pavanimajety
/vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth @yewentao256 @pavanimajety /vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth @yewentao256 @pavanimajety
/vllm/model_executor/layers/mamba @tdoublep @tomeras91 /vllm/model_executor/layers/mamba @tdoublep
/vllm/model_executor/layers/mamba/gdn_linear_attn.py @tdoublep @ZJY0516 @vadiklyutiy /vllm/model_executor/layers/mamba/gdn_linear_attn.py @tdoublep @ZJY0516
/vllm/model_executor/layers/rotary_embedding.py @vadiklyutiy
/vllm/model_executor/model_loader @22quinn /vllm/model_executor/model_loader @22quinn
/vllm/model_executor/layers/batch_invariant.py @yewentao256 /vllm/model_executor/layers/batch_invariant.py @yewentao256
/vllm/ir @ProExpertProg
/vllm/kernels/ @ProExpertProg @tjtanaa
/vllm/kernels/helion @ProExpertProg @zou3519
/vllm/multimodal @DarkLight1337 @ywang96 @NickLucche @tjtanaa /vllm/multimodal @DarkLight1337 @ywang96 @NickLucche @tjtanaa
/vllm/vllm_flash_attn @LucasWilkinson @MatthewBonanni /vllm/vllm_flash_attn @LucasWilkinson @MatthewBonanni
CMakeLists.txt @tlrmchlsmth @LucasWilkinson CMakeLists.txt @tlrmchlsmth @LucasWilkinson
@@ -51,9 +47,9 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
/vllm/v1/attention @LucasWilkinson @MatthewBonanni /vllm/v1/attention @LucasWilkinson @MatthewBonanni
/vllm/v1/attention/backend.py @WoosukKwon @zhuohan123 @youkaichao @alexm-redhat @njhill /vllm/v1/attention/backend.py @WoosukKwon @zhuohan123 @youkaichao @alexm-redhat @njhill
/vllm/v1/attention/backends/mla @pavanimajety /vllm/v1/attention/backends/mla @pavanimajety
/vllm/v1/attention/backends/flashinfer.py @mgoin @pavanimajety @vadiklyutiy /vllm/v1/attention/backends/flashinfer.py @mgoin @pavanimajety
/vllm/v1/attention/backends/triton_attn.py @tdoublep /vllm/v1/attention/backends/triton_attn.py @tdoublep
/vllm/v1/attention/backends/gdn_attn.py @ZJY0516 @vadiklyutiy /vllm/v1/attention/backends/gdn_attn.py @ZJY0516
/vllm/v1/core @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @alexm-redhat @heheda12345 @ApostaC @orozery /vllm/v1/core @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @alexm-redhat @heheda12345 @ApostaC @orozery
/vllm/v1/sample @22quinn @houseroad @njhill /vllm/v1/sample @22quinn @houseroad @njhill
/vllm/v1/spec_decode @benchislett @luccafong @MatthewBonanni /vllm/v1/spec_decode @benchislett @luccafong @MatthewBonanni
@@ -75,9 +71,8 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
/tests/distributed/test_pipeline_parallel.py @youkaichao /tests/distributed/test_pipeline_parallel.py @youkaichao
/tests/distributed/test_same_node.py @youkaichao /tests/distributed/test_same_node.py @youkaichao
/tests/entrypoints @DarkLight1337 @robertgshaw2-redhat @aarnphm @NickLucche /tests/entrypoints @DarkLight1337 @robertgshaw2-redhat @aarnphm @NickLucche
/tests/evals @mgoin @vadiklyutiy /tests/evals @mgoin
/tests/kernels @mgoin @tlrmchlsmth @WoosukKwon @yewentao256 /tests/kernels @mgoin @tlrmchlsmth @WoosukKwon @yewentao256
/tests/kernels/ir @ProExpertProg @tjtanaa
/tests/models @DarkLight1337 @ywang96 /tests/models @DarkLight1337 @ywang96
/tests/multimodal @DarkLight1337 @ywang96 @NickLucche /tests/multimodal @DarkLight1337 @ywang96 @NickLucche
/tests/quantization @mgoin @robertgshaw2-redhat @yewentao256 @pavanimajety /tests/quantization @mgoin @robertgshaw2-redhat @yewentao256 @pavanimajety
@@ -87,7 +82,7 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
/tests/v1/core @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @alexm-redhat @heheda12345 @ApostaC @orozery /tests/v1/core @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @alexm-redhat @heheda12345 @ApostaC @orozery
/tests/weight_loading @mgoin @youkaichao @yewentao256 /tests/weight_loading @mgoin @youkaichao @yewentao256
/tests/lora @jeejeelee /tests/lora @jeejeelee
/tests/models/language/generation/test_hybrid.py @tdoublep @tomeras91 /tests/models/language/generation/test_hybrid.py @tdoublep
/tests/v1/kv_connector/nixl_integration @NickLucche /tests/v1/kv_connector/nixl_integration @NickLucche
/tests/v1/kv_connector @ApostaC @orozery /tests/v1/kv_connector @ApostaC @orozery
/tests/v1/kv_offload @ApostaC @orozery /tests/v1/kv_offload @ApostaC @orozery
@@ -131,14 +126,9 @@ mkdocs.yaml @hmellor
/vllm/platforms/xpu.py @jikunshang /vllm/platforms/xpu.py @jikunshang
/docker/Dockerfile.xpu @jikunshang /docker/Dockerfile.xpu @jikunshang
# Nemotron-specific files
/vllm/model_executor/models/*nemotron* @tomeras91
/vllm/transformers_utils/configs/*nemotron* @tomeras91
/tests/**/*nemotron* @tomeras91
# Qwen-specific files # Qwen-specific files
/vllm/model_executor/models/qwen* @sighingnow @vadiklyutiy /vllm/attention/backends/dual_chunk_flash_attn.py @sighingnow
/vllm/transformers_utils/configs/qwen* @sighingnow @vadiklyutiy /vllm/model_executor/models/qwen* @sighingnow
# MTP-specific files # MTP-specific files
/vllm/model_executor/models/deepseek_mtp.py @luccafong /vllm/model_executor/models/deepseek_mtp.py @luccafong
@@ -154,7 +144,7 @@ mkdocs.yaml @hmellor
# Kernels # Kernels
/vllm/v1/attention/ops/chunked_prefill_paged_decode.py @tdoublep /vllm/v1/attention/ops/chunked_prefill_paged_decode.py @tdoublep
/vllm/v1/attention/ops/triton_unified_attention.py @tdoublep /vllm/v1/attention/ops/triton_unified_attention.py @tdoublep
/vllm/model_executor/layers/fla @ZJY0516 @vadiklyutiy /vllm/model_executor/layers/fla @ZJY0516
# ROCm related: specify owner with write access to notify AMD folks for careful code review # ROCm related: specify owner with write access to notify AMD folks for careful code review
/vllm/**/*rocm* @tjtanaa /vllm/**/*rocm* @tjtanaa

View File

@@ -28,7 +28,6 @@ jobs:
             });
             const hasReadyLabel = pr.labels.some(l => l.name === 'ready');
-            const hasVerifiedLabel = pr.labels.some(l => l.name === 'verified');
             const { data: mergedPRs } = await github.rest.search.issuesAndPullRequests({
               q: `repo:${context.repo.owner}/${context.repo.repo} is:pr is:merged author:${pr.user.login}`,
@@ -36,10 +35,10 @@
             });
             const mergedCount = mergedPRs.total_count;
-            if (hasReadyLabel || hasVerifiedLabel || mergedCount >= 4) {
-              core.info(`Check passed: verified label=${hasVerifiedLabel}, ready label=${hasReadyLabel}, 4+ merged PRs=${mergedCount >= 4}`);
+            if (hasReadyLabel || mergedCount >= 4) {
+              core.info(`Check passed: ready label=${hasReadyLabel}, 4+ merged PRs=${mergedCount >= 4}`);
             } else {
-              core.setFailed(`PR must have the 'verified' or 'ready' (which also triggers tests) label or the author must have at least 4 merged PRs (found ${mergedCount}).`);
+              core.setFailed(`PR must have the 'ready' label or the author must have at least 4 merged PRs (found ${mergedCount}).`);
             }
   pre-commit:

View File

@@ -39,7 +39,7 @@ repos:
     rev: 0.11.1
     hooks:
       - id: pip-compile
-        args: [requirements/test.in, -c, requirements/common.txt, -o, requirements/test.txt, --index-strategy, unsafe-best-match, --torch-backend, cu129, --python-platform, x86_64-manylinux_2_28, --python-version, "3.12"]
+        args: [requirements/test.in, -o, requirements/test.txt, --index-strategy, unsafe-best-match, --torch-backend, cu129, --python-platform, x86_64-manylinux_2_28, --python-version, "3.12"]
         files: ^requirements/test\.(in|txt)$
       - id: pip-compile
         alias: pip-compile-rocm

View File

@@ -309,7 +309,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library") SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
# Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building. # Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building.
set(CUTLASS_REVISION "v4.4.2") set(CUTLASS_REVISION "v4.2.1")
# Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR}) if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR})
@@ -340,8 +340,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_EXT_SRC list(APPEND VLLM_EXT_SRC
"csrc/quantization/awq/gemm_kernels.cu" "csrc/quantization/awq/gemm_kernels.cu"
"csrc/cutlass_extensions/common.cpp" "csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu"
"csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu") "csrc/quantization/fp4/nvfp4_quant_entry.cu"
"csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
"csrc/cutlass_extensions/common.cpp")
set_gencode_flags_for_srcs( set_gencode_flags_for_srcs(
SRCS "${VLLM_EXT_SRC}" SRCS "${VLLM_EXT_SRC}"
@@ -488,6 +490,185 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
" in CUDA target architectures") " in CUDA target architectures")
endif() endif()
set(SCALED_MM_3X_ARCHS)
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
# CUDA 12.0 or later
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS)
set(SRCS
"csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm90.cu"
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8.cu"
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_int8.cu"
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_azp_sm90_int8.cu"
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM90=1")
# Let scaled_mm_c2x know it doesn't need to build these arches
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
message(STATUS "Building scaled_mm_c3x_sm90 for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS)
message(STATUS "Not building scaled_mm_c3x_sm90 as CUDA Compiler version is "
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
"later if you intend on running FP8 quantized models on "
"Hopper.")
else()
message(STATUS "Not building scaled_mm_c3x_sm90 as no compatible archs found "
"in CUDA target architectures")
endif()
endif()
# The cutlass_scaled_mm kernels for Blackwell SM12x (c3x, i.e. CUTLASS 3.x) require
# CUDA 12.8 or later
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0a;12.1a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
set(SRCS
"csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm120.cu"
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8.cu"
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8.cu"
)
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM120=1")
# Let scaled_mm_c2x know it doesn't need to build these arches
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
message(STATUS "Building scaled_mm_c3x_sm120 for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
message(STATUS "Not building scaled_mm_c3x_sm120 as CUDA Compiler version is "
"not >= 12.8, we recommend upgrading to CUDA 12.8 or "
"later if you intend on running FP8 quantized models on "
"Blackwell.")
else()
message(STATUS "Not building scaled_mm_c3x_120 as no compatible archs found "
"in CUDA target architectures")
endif()
endif()
# The cutlass_scaled_mm kernels for Blackwell SM100 (c3x, i.e. CUTLASS 3.x)
# require CUDA 12.8 or later
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
set(SRCS
"csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm100.cu"
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu"
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8.cu"
)
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM100=1")
# Let scaled_mm_c2x know it doesn't need to build these arches
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
message(STATUS "Building scaled_mm_c3x_sm100 for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
message(STATUS "Not building scaled_mm_c3x_sm100 as CUDA Compiler version is "
"not >= 12.8, we recommend upgrading to CUDA 12.8 or "
"later if you intend on running FP8 quantized models on "
"Blackwell.")
else()
message(STATUS "Not building scaled_mm_c3x_100 as no compatible archs found "
"in CUDA target architectures")
endif()
endif()
#
# For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x)
# kernels for the remaining archs that are not already built for 3x.
# (Build 8.9 for FP8)
cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS
"7.5;8.0;8.7;8.9+PTX" "${CUDA_ARCHS}")
# subtract out the archs that are already built for 3x
list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS})
if (SCALED_MM_2X_ARCHS)
set(SRCS "csrc/quantization/w8a8/cutlass/scaled_mm_c2x.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_2X_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_C2X=1")
message(STATUS "Building scaled_mm_c2x for archs: ${SCALED_MM_2X_ARCHS}")
else()
if (SCALED_MM_3X_ARCHS)
message(STATUS "Not building scaled_mm_c2x as all archs are already built"
" for and covered by scaled_mm_c3x")
else()
message(STATUS "Not building scaled_mm_c2x as no compatible archs found "
"in CUDA target architectures")
endif()
endif()
# The nvfp4_scaled_mm_sm120 kernels for Blackwell SM12x require
# CUDA 12.8 or later
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(FP4_ARCHS "12.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(FP4_ARCHS "12.0a;12.1a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
set(SRCS
"csrc/quantization/fp4/nvfp4_quant_kernels.cu"
"csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
"csrc/quantization/fp4/nvfp4_experts_quant.cu"
"csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu"
"csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${FP4_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM120=1")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM120=1")
message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
else()
message(STATUS "Not building NVFP4 as no compatible archs were found.")
# clear FP4_ARCHS
set(FP4_ARCHS)
endif()
# FP4 Archs and flags
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(FP4_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(FP4_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
set(SRCS
"csrc/quantization/fp4/nvfp4_quant_kernels.cu"
"csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
"csrc/quantization/fp4/nvfp4_experts_quant.cu"
"csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu"
"csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${FP4_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM100=1")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
else()
message(STATUS "Not building NVFP4 as no compatible archs were found.")
# clear FP4_ARCHS
set(FP4_ARCHS)
endif()
# CUTLASS MLA Archs and flags # CUTLASS MLA Archs and flags
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(MLA_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}") cuda_archs_loose_intersection(MLA_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
@@ -512,6 +693,55 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set(MLA_ARCHS) set(MLA_ARCHS)
endif() endif()
# CUTLASS MoE kernels
# The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and ONLY works
# on Hopper). get_cutlass_(batched_)moe_mm_data should only be compiled
# if it's possible to compile MoE kernels that use its output.
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
set(SRCS "csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm90.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM90=1")
message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is "
"not >= 12.3, we recommend upgrading to CUDA 12.3 or later "
"if you intend on running FP8 quantized MoE models on Hopper.")
else()
message(STATUS "Not building grouped_mm_c3x as no compatible archs found "
"in CUDA target architectures.")
endif()
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
set(SRCS "csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm100.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is "
"not >= 12.8, we recommend upgrading to CUDA 12.8 or later "
"if you intend on running FP8 quantized MoE models on Blackwell.")
else()
message(STATUS "Not building grouped_mm_c3x as no compatible archs found "
"in CUDA target architectures.")
endif()
endif()
# Expert-specialization MXFP8 blockscaled grouped kernels (SM100+). # Expert-specialization MXFP8 blockscaled grouped kernels (SM100+).
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(ES_MXFP8_GROUPED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}") cuda_archs_loose_intersection(ES_MXFP8_GROUPED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
@@ -557,6 +787,36 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"in CUDA target architectures.") "in CUDA target architectures.")
endif() endif()
# moe_data.cu is used by all CUTLASS MoE kernels.
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)
set(SRCS "csrc/quantization/w8a8/cutlass/moe/moe_data.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${CUTLASS_MOE_DATA_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
message(STATUS "Building moe_data for archs: ${CUTLASS_MOE_DATA_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)
message(STATUS "Not building moe_data as CUDA Compiler version is "
"not >= 12.3, we recommend upgrading to CUDA 12.3 or later "
"if you intend on running FP8 quantized MoE models on Hopper or Blackwell.")
else()
message(STATUS "Not building moe_data as no compatible archs found "
"in CUDA target architectures.")
endif()
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
endif()
# #
# Machete kernels # Machete kernels
@@ -627,6 +887,34 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif() endif()
endif() endif()
# Only build W4A8 kernels if we are building for something compatible with sm90a
cuda_archs_loose_intersection(W4A8_ARCHS "9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND W4A8_ARCHS)
set(SRCS
"csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu"
"csrc/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu"
"csrc/quantization/cutlass_w4a8/w4a8_utils.cu"
)
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${W4A8_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
message(STATUS "Building W4A8 kernels for archs: ${W4A8_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0
AND W4A8_ARCHS)
message(STATUS "Not building W4A8 kernels as CUDA Compiler version is "
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
"later if you intend on running w4a16 quantized models on "
"Hopper.")
else()
message(STATUS "Not building W4A8 kernels as no compatible archs "
"found in CUDA target architectures")
endif()
endif()
# Hadacore kernels # Hadacore kernels
cuda_archs_loose_intersection(HADACORE_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}") cuda_archs_loose_intersection(HADACORE_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}")
@@ -676,12 +964,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# _C_stable_libtorch extension (ops registered via STABLE_TORCH_LIBRARY) # _C_stable_libtorch extension (ops registered via STABLE_TORCH_LIBRARY)
# #
set(VLLM_STABLE_EXT_SRC set(VLLM_STABLE_EXT_SRC
"csrc/libtorch_stable/torch_bindings.cpp" "csrc/libtorch_stable/torch_bindings.cpp")
"csrc/cutlass_extensions/common.cpp"
"csrc/cuda_utils_kernels.cu"
"csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu")
if(VLLM_GPU_LANG STREQUAL "CUDA") if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_STABLE_EXT_SRC list(APPEND VLLM_STABLE_EXT_SRC
@@ -696,299 +979,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
CUDA_ARCHS "${CUDA_ARCHS}") CUDA_ARCHS "${CUDA_ARCHS}")
endif() endif()
#
# CUTLASS scaled_mm kernels (moved from _C to _C_stable_libtorch)
#
set(SCALED_MM_3X_ARCHS)
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
# CUDA 12.0 or later
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS)
set(SRCS
"csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c3x_sm90.cu"
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8.cu"
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_int8.cu"
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_azp_sm90_int8.cu"
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM90=1")
# Let scaled_mm_c2x know it doesn't need to build these arches
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
message(STATUS "Building scaled_mm_c3x_sm90 for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS)
message(STATUS "Not building scaled_mm_c3x_sm90 as CUDA Compiler version is "
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
"later if you intend on running FP8 quantized models on "
"Hopper.")
else()
message(STATUS "Not building scaled_mm_c3x_sm90 as no compatible archs found "
"in CUDA target architectures")
endif()
endif()
# The cutlass_scaled_mm kernels for Blackwell SM12x (c3x, i.e. CUTLASS 3.x) require
# CUDA 12.8 or later
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0a;12.1a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
set(SRCS
"csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c3x_sm120.cu"
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8.cu"
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8.cu"
)
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM120=1")
# Let scaled_mm_c2x know it doesn't need to build these arches
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
message(STATUS "Building scaled_mm_c3x_sm120 for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
message(STATUS "Not building scaled_mm_c3x_sm120 as CUDA Compiler version is "
"not >= 12.8, we recommend upgrading to CUDA 12.8 or "
"later if you intend on running FP8 quantized models on "
"Blackwell.")
else()
message(STATUS "Not building scaled_mm_c3x_120 as no compatible archs found "
"in CUDA target architectures")
endif()
endif()
# The cutlass_scaled_mm kernels for Blackwell SM100 (c3x, i.e. CUTLASS 3.x)
# require CUDA 12.8 or later
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
set(SRCS
"csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c3x_sm100.cu"
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu"
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8.cu"
)
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM100=1")
# Let scaled_mm_c2x know it doesn't need to build these arches
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
message(STATUS "Building scaled_mm_c3x_sm100 for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
message(STATUS "Not building scaled_mm_c3x_sm100 as CUDA Compiler version is "
"not >= 12.8, we recommend upgrading to CUDA 12.8 or "
"later if you intend on running FP8 quantized models on "
"Blackwell.")
else()
message(STATUS "Not building scaled_mm_c3x_100 as no compatible archs found "
"in CUDA target architectures")
endif()
endif()
#
# For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x)
# kernels for the remaining archs that are not already built for 3x.
# (Build 8.9 for FP8)
cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS
"7.5;8.0;8.7;8.9+PTX" "${CUDA_ARCHS}")
# subtract out the archs that are already built for 3x
list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS})
if (SCALED_MM_2X_ARCHS)
set(SRCS "csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c2x.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_2X_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_C2X=1")
message(STATUS "Building scaled_mm_c2x for archs: ${SCALED_MM_2X_ARCHS}")
else()
if (SCALED_MM_3X_ARCHS)
message(STATUS "Not building scaled_mm_c2x as all archs are already built"
" for and covered by scaled_mm_c3x")
else()
message(STATUS "Not building scaled_mm_c2x as no compatible archs found "
"in CUDA target architectures")
endif()
endif()
#
# CUTLASS MoE kernels (moved from _C to _C_stable_libtorch)
#
# The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and ONLY works
# on Hopper). get_cutlass_(batched_)moe_mm_data should only be compiled
# if it's possible to compile MoE kernels that use its output.
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
set(SRCS "csrc/libtorch_stable/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm90.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM90=1")
message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is "
"not >= 12.3, we recommend upgrading to CUDA 12.3 or later "
"if you intend on running FP8 quantized MoE models on Hopper.")
else()
message(STATUS "Not building grouped_mm_c3x as no compatible archs found "
"in CUDA target architectures.")
endif()
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
set(SRCS "csrc/libtorch_stable/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm100.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is "
"not >= 12.8, we recommend upgrading to CUDA 12.8 or later "
"if you intend on running FP8 quantized MoE models on Blackwell.")
else()
message(STATUS "Not building grouped_mm_c3x as no compatible archs found "
"in CUDA target architectures.")
endif()
endif()
# moe_data.cu is used by all CUTLASS MoE kernels.
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)
set(SRCS "csrc/libtorch_stable/quantization/w8a8/cutlass/moe/moe_data.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${CUTLASS_MOE_DATA_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
message(STATUS "Building moe_data for archs: ${CUTLASS_MOE_DATA_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)
message(STATUS "Not building moe_data as CUDA Compiler version is "
"not >= 12.3, we recommend upgrading to CUDA 12.3 or later "
"if you intend on running FP8 quantized MoE models on Hopper or Blackwell.")
else()
message(STATUS "Not building moe_data as no compatible archs found "
"in CUDA target architectures.")
endif()
endif()
#
# FP4/NVFP4 kernels (moved from _C to _C_stable_libtorch)
#
# The nvfp4_scaled_mm_sm120 kernels for Blackwell SM12x require
# CUDA 12.8 or later
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(FP4_ARCHS "12.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(FP4_ARCHS "12.0a;12.1a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
set(SRCS
"csrc/libtorch_stable/quantization/fp4/nvfp4_quant_kernels.cu"
"csrc/libtorch_stable/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_experts_quant.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_blockwise_moe_kernel.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${FP4_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM120=1")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM120=1")
message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
else()
message(STATUS "Not building NVFP4 as no compatible archs were found.")
# clear FP4_ARCHS
set(FP4_ARCHS)
endif()
# FP4 Archs and flags
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(FP4_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(FP4_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
set(SRCS
"csrc/libtorch_stable/quantization/fp4/nvfp4_quant_kernels.cu"
"csrc/libtorch_stable/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_experts_quant.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_kernels.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_blockwise_moe_kernel.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${FP4_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM100=1")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
else()
message(STATUS "Not building NVFP4 as no compatible archs were found.")
# clear FP4_ARCHS
set(FP4_ARCHS)
endif()
#
# W4A8 kernels (moved from _C to _C_stable_libtorch)
#
# Only build W4A8 kernels if we are building for something compatible with sm90a
cuda_archs_loose_intersection(W4A8_ARCHS "9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND W4A8_ARCHS)
set(SRCS
"csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_mm_entry.cu"
"csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu"
"csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_utils.cu"
)
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${W4A8_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
message(STATUS "Building W4A8 kernels for archs: ${W4A8_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0
AND W4A8_ARCHS)
message(STATUS "Not building W4A8 kernels as CUDA Compiler version is "
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
"later if you intend on running w4a16 quantized models on "
"Hopper.")
else()
message(STATUS "Not building W4A8 kernels as no compatible archs "
"found in CUDA target architectures")
endif()
endif()
message(STATUS "Enabling C_stable extension.") message(STATUS "Enabling C_stable extension.")
define_extension_target( define_extension_target(
_C_stable_libtorch _C_stable_libtorch
@@ -997,7 +987,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
SOURCES ${VLLM_STABLE_EXT_SRC} SOURCES ${VLLM_STABLE_EXT_SRC}
COMPILE_FLAGS ${VLLM_GPU_FLAGS} COMPILE_FLAGS ${VLLM_GPU_FLAGS}
ARCHITECTURES ${VLLM_GPU_ARCHES} ARCHITECTURES ${VLLM_GPU_ARCHES}
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
USE_SABI 3 USE_SABI 3
WITH_SOABI) WITH_SOABI)
@@ -1011,10 +1000,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# Needed to use cuda APIs from C-shim # Needed to use cuda APIs from C-shim
target_compile_definitions(_C_stable_libtorch PRIVATE target_compile_definitions(_C_stable_libtorch PRIVATE
USE_CUDA) USE_CUDA)
# Needed by CUTLASS kernels
target_compile_definitions(_C_stable_libtorch PRIVATE
CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
endif() endif()
# #
@@ -1030,6 +1015,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_MOE_EXT_SRC list(APPEND VLLM_MOE_EXT_SRC
"csrc/moe/moe_wna16.cu" "csrc/moe/moe_wna16.cu"
"csrc/moe/grouped_topk_kernels.cu" "csrc/moe/grouped_topk_kernels.cu"
"csrc/moe/gpt_oss_router_gemm.cu"
"csrc/moe/router_gemm.cu") "csrc/moe/router_gemm.cu")
endif() endif()

View File

@@ -1,264 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Benchmark: Fused FP8 output quantization in merge_attn_states
Compares fused vs unfused approaches for producing FP8-quantized merged
attention output:
1. Fused CUDA -- single CUDA kernel (merge + FP8 quant)
2. Fused Triton -- single Triton kernel (merge + FP8 quant)
3. Unfused CUDA -- CUDA merge + torch.compiled FP8 quant
4. Unfused Triton -- Triton merge + torch.compiled FP8 quant
Usage:
python benchmarks/fused_kernels/merge_attn_states_benchmarks.py
python benchmarks/fused_kernels/merge_attn_states_benchmarks.py --tp 1 4 8
python benchmarks/fused_kernels/merge_attn_states_benchmarks.py --dtype bfloat16
"""
import argparse
import itertools
import torch
from vllm._custom_ops import merge_attn_states as merge_attn_states_cuda
from vllm.benchmarks.lib.utils import default_vllm_config
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform
from vllm.triton_utils import triton
from vllm.v1.attention.ops.triton_merge_attn_states import (
merge_attn_states as merge_attn_states_triton,
)
# ---------------------------------------------------------------------------
# Configuration defaults
# ---------------------------------------------------------------------------
NUM_TOKENS_LIST = [1, 16, 64, 256, 1024, 4096]
# (label, num_heads, head_size) — num_heads is for TP=1
HEAD_CONFIGS = [
("DeepSeek-V3 MLA", 128, 128),
("Llama-70B", 64, 128),
("Llama-8B", 32, 128),
]
TP_SIZES = [1, 2, 4, 8]
INPUT_DTYPES = [torch.float32, torch.float16, torch.bfloat16]
QUANTILES = [0.5, 0.2, 0.8]
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def short_dtype(dtype: torch.dtype) -> str:
return str(dtype).removeprefix("torch.")
def make_inputs(
num_tokens: int,
num_heads: int,
head_size: int,
dtype: torch.dtype,
):
"""Create random prefix/suffix outputs and LSEs."""
prefix_output = torch.randn(
(num_tokens, num_heads, head_size), dtype=dtype, device="cuda"
)
suffix_output = torch.randn(
(num_tokens, num_heads, head_size), dtype=dtype, device="cuda"
)
prefix_lse = torch.randn(num_heads, num_tokens, dtype=torch.float32, device="cuda")
suffix_lse = torch.randn(num_heads, num_tokens, dtype=torch.float32, device="cuda")
# Sprinkle some inf values to exercise edge-case paths
mask = torch.rand(num_heads, num_tokens, device="cuda") < 0.05
prefix_lse[mask] = float("inf")
mask2 = torch.rand(num_heads, num_tokens, device="cuda") < 0.05
suffix_lse[mask2] = float("inf")
return prefix_output, suffix_output, prefix_lse, suffix_lse
def build_configs(head_configs, num_tokens_list, input_dtypes, tp_sizes):
"""Build (num_tokens, num_heads, head_size, dtype_str) config tuples,
applying TP division to num_heads and skipping invalid combos."""
configs = []
for (_, nh, hs), nt, dtype, tp in itertools.product(
head_configs, num_tokens_list, input_dtypes, tp_sizes
):
nh_tp = nh // tp
if nh_tp >= 1:
configs.append((nt, nh_tp, hs, short_dtype(dtype)))
return configs
def parse_args():
parser = argparse.ArgumentParser(
description="Benchmark merge_attn_states fused FP8 quantization"
)
parser.add_argument(
"--num-tokens",
type=int,
nargs="+",
default=None,
help=f"Override token counts (default: {NUM_TOKENS_LIST})",
)
parser.add_argument(
"--tp",
type=int,
nargs="+",
default=None,
help=f"TP sizes to simulate (divides num_heads) (default: {TP_SIZES})",
)
parser.add_argument(
"--dtype",
type=str,
nargs="+",
default=None,
help="Input dtypes (e.g. bfloat16 float16 float32). "
f"Default: {[short_dtype(d) for d in INPUT_DTYPES]}",
)
return parser.parse_args()
# ---------------------------------------------------------------------------
# Parse args and build configs before decorators
# ---------------------------------------------------------------------------
args = parse_args()
num_tokens_list = args.num_tokens if args.num_tokens else NUM_TOKENS_LIST
tp_sizes = args.tp if args.tp else TP_SIZES
if args.dtype:
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
input_dtypes = [STR_DTYPE_TO_TORCH_DTYPE[d] for d in args.dtype]
else:
input_dtypes = INPUT_DTYPES
configs = build_configs(HEAD_CONFIGS, num_tokens_list, input_dtypes, tp_sizes)
torch._dynamo.config.recompile_limit = 8888
# ---------------------------------------------------------------------------
# Benchmark function
# ---------------------------------------------------------------------------
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["num_tokens", "num_heads", "head_size", "dtype_str"],
x_vals=configs,
line_arg="provider",
line_vals=["fused_cuda", "fused_triton", "unfused_cuda", "unfused_triton"],
line_names=["Fused CUDA", "Fused Triton", "Unfused CUDA", "Unfused Triton"],
styles=[("blue", "-"), ("green", "-"), ("blue", "--"), ("green", "--")],
ylabel="us",
plot_name="merge_attn_states FP8 (fused vs unfused)",
args={},
)
)
@default_vllm_config()
def benchmark(num_tokens, num_heads, head_size, dtype_str, provider):
input_dtype = getattr(torch, dtype_str)
fp8_dtype = current_platform.fp8_dtype()
prefix_out, suffix_out, prefix_lse, suffix_lse = make_inputs(
num_tokens, num_heads, head_size, input_dtype
)
output_scale = torch.tensor([0.1], dtype=torch.float32, device="cuda")
if provider == "fused_cuda":
output = torch.empty(
(num_tokens, num_heads, head_size), dtype=fp8_dtype, device="cuda"
)
fn = lambda: merge_attn_states_cuda(
output,
prefix_out,
prefix_lse,
suffix_out,
suffix_lse,
output_scale=output_scale,
)
elif provider == "fused_triton":
output = torch.empty(
(num_tokens, num_heads, head_size), dtype=fp8_dtype, device="cuda"
)
fn = lambda: merge_attn_states_triton(
output,
prefix_out,
prefix_lse,
suffix_out,
suffix_lse,
output_scale=output_scale,
)
elif provider == "unfused_cuda":
merge_buf = torch.empty(
(num_tokens, num_heads, head_size), dtype=input_dtype, device="cuda"
)
quant_fp8 = QuantFP8(
static=True,
group_shape=GroupShape.PER_TENSOR,
column_major_scales=False,
)
quant_input = merge_buf.view(-1, head_size)
compiled_quant = torch.compile(
quant_fp8.forward_native, fullgraph=True, dynamic=False
)
def unfused_fn():
merge_attn_states_cuda(
merge_buf, prefix_out, prefix_lse, suffix_out, suffix_lse
)
compiled_quant(quant_input, output_scale)
fn = unfused_fn
else: # unfused_triton
merge_buf = torch.empty(
(num_tokens, num_heads, head_size), dtype=input_dtype, device="cuda"
)
quant_fp8 = QuantFP8(
static=True,
group_shape=GroupShape.PER_TENSOR,
column_major_scales=False,
)
quant_input = merge_buf.view(-1, head_size)
compiled_quant = torch.compile(
quant_fp8.forward_native, fullgraph=True, dynamic=False
)
def unfused_fn():
merge_attn_states_triton(
merge_buf, prefix_out, prefix_lse, suffix_out, suffix_lse
)
compiled_quant(quant_input, output_scale)
fn = unfused_fn
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=QUANTILES)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms # us
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
device_name = current_platform.get_device_name()
print(f"Device: {device_name}")
print(f"Token counts: {num_tokens_list}")
print(f"TP sizes: {tp_sizes}")
print(f"Input dtypes: {[short_dtype(d) for d in input_dtypes]}")
print(f"Head configs: {[(c[0], c[1], c[2]) for c in HEAD_CONFIGS]}")
benchmark.run(print_data=True)
if __name__ == "__main__":
with torch.inference_mode():
main()
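As a reading aid for the benchmark above: both the fused and unfused providers implement the standard log-sum-exp rescaling used to combine split-KV partial attention results, followed by a static FP8 cast. The sketch below is an illustrative plain-PyTorch restatement of that math only; the helper name is hypothetical, the division-by-scale convention and the ±448 float8_e4m3fn clamp are assumptions, and the edge case where both LSEs are -inf (special-cased in the CUDA kernel) is not reproduced.

# Illustrative reference only, not a vLLM API.
import torch

def merge_then_quant_ref(p_out, p_lse, s_out, s_lse, scale):
    # p_out/s_out: [tokens, heads, head_size]; p_lse/s_lse: [heads, tokens]
    p_lse = p_lse.masked_fill(torch.isinf(p_lse), float("-inf"))
    s_lse = s_lse.masked_fill(torch.isinf(s_lse), float("-inf"))
    max_lse = torch.maximum(p_lse, s_lse)
    p_se = torch.exp(p_lse - max_lse)          # [heads, tokens]
    s_se = torch.exp(s_lse - max_lse)
    out_se = p_se + s_se
    # Per-(token, head) blending weights, broadcast over head_size.
    p_scale = (p_se / out_se).transpose(0, 1).unsqueeze(-1)
    s_scale = (s_se / out_se).transpose(0, 1).unsqueeze(-1)
    merged = p_out.float() * p_scale + s_out.float() * s_scale
    # Static FP8 quantization: divide by the provided scale, clamp to e4m3 range.
    return (merged / scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)

With the same shapes produced by make_inputs above, this can serve as a slow correctness oracle when comparing the fused providers against the unfused ones.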

View File

@@ -1,211 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable, Iterable
from dataclasses import dataclass
from itertools import product
import torch
import torch.nn.functional as F
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from tqdm import tqdm
import vllm._custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8,
)
@dataclass
class bench_params_t:
num_tokens: int
hidden_size: int
dtype: torch.dtype
group_size: int # Changed from list[int] to int
def description(self):
return (
f"N {self.num_tokens} "
f"x D {self.hidden_size} "
f"x DT {self.dtype} "
f"x GS {self.group_size}"
)
def get_bench_params() -> list[bench_params_t]:
"""Test configurations covering common model sizes."""
NUM_TOKENS = [16, 128, 512, 2048]
HIDDEN_SIZES = [1024, 2048, 4096, 5120, 14336] # Common FFN sizes
DTYPES = [torch.float16, torch.bfloat16]
GROUP_SIZES = [64, 128] # Changed from [[1, 64], [1, 128]]
combinations = product(NUM_TOKENS, HIDDEN_SIZES, DTYPES, GROUP_SIZES)
bench_params = list(
map(lambda x: bench_params_t(x[0], x[1], x[2], x[3]), combinations)
)
return bench_params
# Reference implementations
def unfused_fp8_impl(
x: torch.Tensor,
quant_dtype: torch.dtype,
group_size: int, # Changed from list[int]
):
"""Unfused: SiLU+Mul then per-tensor quantize."""
hidden = x.shape[-1] // 2
gate, up = x.split(hidden, dim=-1)
# SiLU(gate) * up
silu_out = F.silu(gate) * up
# Per-tensor quantize (no group_size used here)
silu_out, _ = ops.scaled_fp8_quant(silu_out)
def unfused_groupwise_fp8_impl(
x: torch.Tensor,
quant_dtype: torch.dtype,
group_size: int, # Changed from list[int]
):
"""Unfused: SiLU+Mul then group-wise quantize."""
hidden = x.shape[-1] // 2
gate, up = x.split(hidden, dim=-1)
# SiLU(gate) * up
silu_out = F.silu(gate) * up
# Group quantize - use group_size directly
silu_out, _ = per_token_group_quant_fp8(
silu_out, group_size=group_size, use_ue8m0=False
)
def fused_impl(
x: torch.Tensor,
quant_dtype: torch.dtype,
group_size: int,
):
"""Fused: SiLU+Mul+Block Quantization in single kernel."""
out, _ = ops.silu_and_mul_per_block_quant(
x,
group_size=group_size,
quant_dtype=quant_dtype,
is_scale_transposed=False,
)
# Bench functions
def bench_fn(
x: torch.Tensor,
quant_dtype: torch.dtype,
group_size: int,
label: str,
sub_label: str,
fn: Callable,
description: str,
) -> TMeasurement:
min_run_time = 1
globals = {
"x": x,
"quant_dtype": quant_dtype,
"group_size": group_size,
"fn": fn,
}
return TBenchmark.Timer(
stmt="fn(x, quant_dtype, group_size)",
globals=globals,
label=label,
sub_label=sub_label,
description=description,
).blocked_autorange(min_run_time=min_run_time)
def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasurement]:
"""Run benchmarks for all implementations."""
# Make inputs: [num_tokens, hidden_size * 2] for [gate || up]
scale = 1 / params.hidden_size
x = (
torch.randn(
params.num_tokens,
params.hidden_size * 2,
dtype=params.dtype,
device="cuda",
)
* scale
)
timers = []
# Unfused per-tensor FP8
timers.append(
bench_fn(
x,
torch.float8_e4m3fn,
params.group_size,
label,
sub_label,
unfused_fp8_impl,
"unfused_fp8_impl",
)
)
# Unfused group-wise FP8
timers.append(
bench_fn(
x,
torch.float8_e4m3fn,
params.group_size,
label,
sub_label,
unfused_groupwise_fp8_impl,
"unfused_groupwise_fp8_impl",
)
)
# Fused group-wise FP8
timers.append(
bench_fn(
x,
torch.float8_e4m3fn,
params.group_size,
label,
sub_label,
fused_impl,
"fused_groupwise_fp8_impl",
)
)
return timers
def print_timers(timers: Iterable[TMeasurement]):
compare = TBenchmark.Compare(timers)
compare.print()
def main():
torch.set_default_device("cuda")
bench_params = get_bench_params()
print(f"Running {len(bench_params)} benchmark configurations...")
print(
f"This will take approximately {len(bench_params) * 3} seconds (1s per variant)"
)
print()
timers = []
for bp in tqdm(bench_params):
result_timers = bench(bp, "silu-mul-block-quant", bp.description())
timers.extend(result_timers)
print("\n" + "=" * 80)
print("FINAL COMPARISON - ALL RESULTS")
print("=" * 80)
print_timers(timers)
if __name__ == "__main__":
main()
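For readers skimming the benchmark above, the unfused group-wise path amounts to "SiLU(gate) * up, then one FP8 scale per contiguous group of the hidden dimension". The sketch below is a plain-PyTorch illustration of that idea under stated assumptions (symmetric per-group scaling against the float8_e4m3fn maximum of 448, no UE8M0 scale encoding); the helper name is hypothetical and this is not the vLLM kernel being timed.

# Illustrative reference only, not a vLLM API.
import torch
import torch.nn.functional as F

def silu_mul_group_quant_ref(x: torch.Tensor, group_size: int):
    # x: [num_tokens, 2 * hidden]; first half is gate, second half is up.
    gate, up = x.chunk(2, dim=-1)
    y = F.silu(gate) * up                                   # [tokens, hidden]
    t, h = y.shape
    groups = y.reshape(t, h // group_size, group_size).float()
    # One scale per (token, group): amax / fp8_max, guarded against all-zero groups.
    scales = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / 448.0
    q = (groups / scales).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
    return q.reshape(t, h), scales.squeeze(-1)

The fused kernel is expected to produce an equivalent result in a single pass, without materializing the intermediate silu_out buffer that the unfused implementations allocate.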

View File

@@ -0,0 +1,134 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.nn.functional as F
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_config
from vllm.triton_utils import triton
from vllm.utils.argparse_utils import FlexibleArgumentParser
# Dimensions supported by the DSV3 specialized kernel
DSV3_SUPPORTED_NUM_EXPERTS = [256, 384]
DSV3_SUPPORTED_HIDDEN_SIZES = [7168]
# Dimensions supported by the gpt-oss specialized kernel
GPT_OSS_SUPPORTED_NUM_EXPERTS = [32, 128]
GPT_OSS_SUPPORTED_HIDDEN_SIZES = [2880]
def get_batch_size_range(max_batch_size):
return [2**x for x in range(14) if 2**x <= max_batch_size]
def get_model_params(config):
if config.architectures[0] in (
"DeepseekV2ForCausalLM",
"DeepseekV3ForCausalLM",
"DeepseekV32ForCausalLM",
):
num_experts = config.n_routed_experts
hidden_size = config.hidden_size
elif config.architectures[0] in ("GptOssForCausalLM",):
num_experts = config.num_local_experts
hidden_size = config.hidden_size
else:
raise ValueError(f"Unsupported architecture: {config.architectures}")
return num_experts, hidden_size
def get_benchmark(model, max_batch_size, trust_remote_code):
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=get_batch_size_range(max_batch_size),
x_log=False,
line_arg="provider",
line_vals=[
"torch",
"vllm",
],
line_names=["PyTorch", "vLLM"],
styles=([("blue", "-"), ("red", "-")]),
ylabel="TFLOPs",
plot_name=f"{model} router gemm throughput",
args={},
)
)
def benchmark(batch_size, provider):
config = get_config(model=model, trust_remote_code=trust_remote_code)
num_experts, hidden_size = get_model_params(config)
mat_a = torch.randn(
(batch_size, hidden_size), dtype=torch.bfloat16, device="cuda"
).contiguous()
mat_b = torch.randn(
(num_experts, hidden_size), dtype=torch.bfloat16, device="cuda"
).contiguous()
bias = torch.randn(
num_experts, dtype=torch.bfloat16, device="cuda"
).contiguous()
is_hopper_or_blackwell = current_platform.is_device_capability(
90
) or current_platform.is_device_capability_family(100)
allow_dsv3_router_gemm = (
is_hopper_or_blackwell
and num_experts in DSV3_SUPPORTED_NUM_EXPERTS
and hidden_size in DSV3_SUPPORTED_HIDDEN_SIZES
)
allow_gpt_oss_router_gemm = (
is_hopper_or_blackwell
and num_experts in GPT_OSS_SUPPORTED_NUM_EXPERTS
and hidden_size in GPT_OSS_SUPPORTED_HIDDEN_SIZES
)
has_bias = False
if allow_gpt_oss_router_gemm:
has_bias = True
quantiles = [0.5, 0.2, 0.8]
if provider == "torch":
def runner():
if has_bias:
F.linear(mat_a, mat_b, bias)
else:
F.linear(mat_a, mat_b)
elif provider == "vllm":
def runner():
if allow_dsv3_router_gemm:
ops.dsv3_router_gemm(mat_a, mat_b, torch.bfloat16)
elif allow_gpt_oss_router_gemm:
ops.gpt_oss_router_gemm(mat_a, mat_b, bias)
else:
raise ValueError("Unsupported router gemm")
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
runner, quantiles=quantiles
)
def tflops(t_ms):
flops = 2 * batch_size * hidden_size * num_experts
return flops / (t_ms * 1e-3) / 1e12
return tflops(ms), tflops(max_ms), tflops(min_ms)
return benchmark
if __name__ == "__main__":
parser = FlexibleArgumentParser()
parser.add_argument("--model", type=str, default="openai/gpt-oss-20b")
parser.add_argument("--max-batch-size", default=16, type=int)
parser.add_argument("--trust-remote-code", action="store_true")
args = parser.parse_args()
# Get the benchmark function
benchmark = get_benchmark(args.model, args.max_batch_size, args.trust_remote_code)
# Run performance benchmark
benchmark.run(print_data=True)
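A quick worked example of the throughput metric reported by this script: the router GEMM performs 2 * batch_size * hidden_size * num_experts FLOPs per call, so for the gpt-oss dimensions listed above (hidden_size 2880, 32 experts) at the default maximum batch size of 16, the arithmetic looks like the sketch below. The 5 µs kernel time is a made-up placeholder, not a measured number.

# Worked example of the TFLOPs conversion used in benchmark() above.
batch_size, hidden_size, num_experts = 16, 2880, 32
flops = 2 * batch_size * hidden_size * num_experts   # 2,949,120 FLOPs per call
t_ms = 0.005                                          # hypothetical 5 us kernel time
tflops = flops / (t_ms * 1e-3) / 1e12                 # ~0.59 TFLOP/s
print(f"{flops=} tflops={tflops:.2f}")

Batch sizes are swept over powers of two up to --max-batch-size, so the plot shows how quickly each provider approaches its peak as the GEMM becomes less bandwidth-bound.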

View File

@@ -1,162 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Benchmarks the fused Triton bilinear position-embedding kernel against
# the pure-PyTorch (native) implementation used in Qwen3-VL ViT models.
#
# == Usage Examples ==
#
# Default benchmark:
# python3 benchmark_vit_bilinear_pos_embed.py
#
# Custom parameters:
# python3 benchmark_vit_bilinear_pos_embed.py --hidden-dim 1152 \
# --num-grid-per-side 48 --save-path ./configs/vit_pos_embed/
import itertools
import torch
from vllm.model_executor.models.qwen3_vl import (
pos_embed_interpolate_native,
triton_pos_embed_interpolate,
)
from vllm.triton_utils import HAS_TRITON, triton
from vllm.utils.argparse_utils import FlexibleArgumentParser
# (h, w) configurations to benchmark
h_w_configs = [
(16, 16),
(32, 32),
(48, 48),
(64, 64),
(128, 128),
(32, 48),
(60, 80),
]
# Temporal dimensions
t_range = [1]
configs = list(itertools.product(t_range, h_w_configs))
def get_benchmark(
num_grid_per_side: int,
spatial_merge_size: int,
hidden_dim: int,
dtype: torch.dtype,
device: str,
):
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["t", "h_w"],
x_vals=[list(_) for _ in configs],
line_arg="provider",
line_vals=["native", "triton"],
line_names=["Native (PyTorch)", "Triton"],
styles=[("blue", "-"), ("red", "-")],
ylabel="us",
plot_name=(
f"vit-bilinear-pos-embed-"
f"grid{num_grid_per_side}-"
f"dim{hidden_dim}-"
f"{dtype}"
),
args={},
)
)
def benchmark(t, h_w, provider):
h, w = h_w
torch.manual_seed(42)
embed_weight = (
torch.randn(
num_grid_per_side * num_grid_per_side,
hidden_dim,
device=device,
dtype=dtype,
)
* 0.25
)
quantiles = [0.5, 0.2, 0.8]
if provider == "native":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: pos_embed_interpolate_native(
embed_weight,
t,
h,
w,
num_grid_per_side,
spatial_merge_size,
dtype,
),
quantiles=quantiles,
)
else:
assert HAS_TRITON, "Triton not available"
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: triton_pos_embed_interpolate(
embed_weight,
t,
h,
w,
num_grid_per_side,
spatial_merge_size,
dtype,
),
quantiles=quantiles,
)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
return benchmark
if __name__ == "__main__":
parser = FlexibleArgumentParser(
description="Benchmark bilinear position embedding interpolation."
)
parser.add_argument(
"--num-grid-per-side",
type=int,
default=48,
help="Position embedding grid size (default: 48 for Qwen3-VL)",
)
parser.add_argument(
"--spatial-merge-size",
type=int,
default=2,
help="Spatial merge size (default: 2)",
)
parser.add_argument(
"--hidden-dim",
type=int,
default=1152,
help="Embedding hidden dimension (default: 1152 for Qwen3-VL)",
)
parser.add_argument(
"--device",
type=str,
choices=["cuda:0", "cuda:1"],
default="cuda:0",
)
parser.add_argument(
"--save-path",
type=str,
default="./vit_pos_embed/",
)
args = parser.parse_args()
dtype = torch.bfloat16
bench = get_benchmark(
args.num_grid_per_side,
args.spatial_merge_size,
args.hidden_dim,
dtype,
args.device,
)
bench.run(print_data=True, save_path=args.save_path)
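For orientation, the operation being benchmarked resizes a square position-embedding table to the (h, w) patch grid of the input image. The sketch below is a naive baseline using torch's built-in bilinear interpolation; it is not the vLLM native or Triton implementation timed above, and the corner-alignment convention and the spatial-merge reordering those paths apply are assumptions left out here.

# Illustrative baseline only, not the benchmarked vLLM implementation.
import torch
import torch.nn.functional as F

def naive_pos_embed_resize(weight: torch.Tensor, grid: int, h: int, w: int):
    # weight: [grid * grid, D] -> [1, D, grid, grid]
    d = weight.shape[1]
    table = weight.reshape(grid, grid, d).permute(2, 0, 1).unsqueeze(0)
    resized = F.interpolate(
        table.float(), size=(h, w), mode="bilinear", align_corners=True
    )
    # Back to one embedding per (h, w) position: [h * w, D]
    return resized.squeeze(0).permute(1, 2, 0).reshape(h * w, d).to(weight.dtype)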

View File

@@ -373,7 +373,6 @@ if (ENABLE_X86_ISA)
"csrc/cpu/sgl-kernels/gemm.cpp" "csrc/cpu/sgl-kernels/gemm.cpp"
"csrc/cpu/sgl-kernels/gemm_int8.cpp" "csrc/cpu/sgl-kernels/gemm_int8.cpp"
"csrc/cpu/sgl-kernels/gemm_fp8.cpp" "csrc/cpu/sgl-kernels/gemm_fp8.cpp"
"csrc/cpu/sgl-kernels/gemm_int4.cpp"
"csrc/cpu/sgl-kernels/moe.cpp" "csrc/cpu/sgl-kernels/moe.cpp"
"csrc/cpu/sgl-kernels/moe_int8.cpp" "csrc/cpu/sgl-kernels/moe_int8.cpp"
"csrc/cpu/sgl-kernels/moe_fp8.cpp") "csrc/cpu/sgl-kernels/moe_fp8.cpp")

View File

@@ -39,7 +39,7 @@ else()
FetchContent_Declare(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG c0ec424fd8a546d0cbbf4bf050bbcfe837c55afb
GIT_TAG 29210221863736a08f71a866459e368ad1ac4a95
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn

View File

@@ -3,33 +3,22 @@
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include <algorithm> #include <algorithm>
#include <limits>
#include "attention_dtypes.h" #include "attention_dtypes.h"
#include "attention_utils.cuh" #include "attention_utils.cuh"
#include "../quantization/w8a8/fp8/common.cuh"
#include "../dispatch_utils.h"
namespace vllm { namespace vllm {
// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005 // Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
// can be used to combine partial attention results (in the split-KV case) // can be used to combine partial attention results (in the split-KV case)
template <typename scalar_t, typename output_t, const uint NUM_THREADS, template <typename scalar_t, const uint NUM_THREADS>
bool USE_FP8_OUTPUT>
__global__ void merge_attn_states_kernel( __global__ void merge_attn_states_kernel(
output_t* output, float* output_lse, const scalar_t* prefix_output, scalar_t* output, float* output_lse, const scalar_t* prefix_output,
const float* prefix_lse, const scalar_t* suffix_output, const float* prefix_lse, const scalar_t* suffix_output,
const float* suffix_lse, const uint num_tokens, const uint num_heads, const float* suffix_lse, const uint num_tokens, const uint num_heads,
const uint head_size, const uint prefix_head_stride, const uint head_size, const uint prefix_head_stride,
const uint output_head_stride, const uint prefix_num_tokens, const uint output_head_stride) {
const float* output_scale) { using pack_128b_t = uint4;
// Inputs always load 128-bit packs (pack_size elements of scalar_t).
// Outputs store pack_size elements of output_t, which is smaller for FP8.
using input_pack_t = uint4;
using output_pack_t =
std::conditional_t<USE_FP8_OUTPUT,
std::conditional_t<sizeof(scalar_t) == 4, uint, uint2>,
uint4>;
const uint pack_size = 16 / sizeof(scalar_t); const uint pack_size = 16 / sizeof(scalar_t);
const uint threads_per_head = head_size / pack_size; const uint threads_per_head = head_size / pack_size;
@@ -52,45 +41,8 @@ __global__ void merge_attn_states_kernel(
head_idx * output_head_stride; head_idx * output_head_stride;
const scalar_t* prefix_head_ptr = prefix_output + src_head_offset; const scalar_t* prefix_head_ptr = prefix_output + src_head_offset;
const scalar_t* suffix_head_ptr = suffix_output + src_head_offset; const scalar_t* suffix_head_ptr = suffix_output + src_head_offset;
output_t* output_head_ptr = output + dst_head_offset; scalar_t* output_head_ptr = output + dst_head_offset;
// Pre-invert scale: multiplication is faster than division
float fp8_scale_inv = 1.0f;
if constexpr (USE_FP8_OUTPUT) {
fp8_scale_inv = 1.0f / *output_scale;
}
// If token_idx >= prefix_num_tokens, just copy from suffix
if (token_idx >= prefix_num_tokens) {
if (pack_offset < head_size) {
input_pack_t s_out_pack = reinterpret_cast<const input_pack_t*>(
suffix_head_ptr)[pack_offset / pack_size];
if constexpr (USE_FP8_OUTPUT) {
output_t o_out_pack[pack_size];
#pragma unroll
for (uint i = 0; i < pack_size; ++i) {
const float val =
vllm::to_float(reinterpret_cast<const scalar_t*>(&s_out_pack)[i]);
o_out_pack[i] =
vllm::scaled_fp8_conversion<true, output_t>(val, fp8_scale_inv);
}
reinterpret_cast<output_pack_t*>(
output_head_ptr)[pack_offset / pack_size] =
*reinterpret_cast<output_pack_t*>(o_out_pack);
} else {
reinterpret_cast<output_pack_t*>(
output_head_ptr)[pack_offset / pack_size] = s_out_pack;
}
}
if (output_lse != nullptr && pack_idx == 0) {
float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
output_lse[head_idx * num_tokens + token_idx] = s_lse;
}
return;
}
// For tokens within prefix range, merge prefix and suffix
float p_lse = prefix_lse[head_idx * num_tokens + token_idx]; float p_lse = prefix_lse[head_idx * num_tokens + token_idx];
float s_lse = suffix_lse[head_idx * num_tokens + token_idx]; float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
p_lse = std::isinf(p_lse) ? -std::numeric_limits<float>::infinity() : p_lse; p_lse = std::isinf(p_lse) ? -std::numeric_limits<float>::infinity() : p_lse;
@@ -101,34 +53,20 @@ __global__ void merge_attn_states_kernel(
/* In certain edge cases, MLA can produce p_lse = s_lse = -inf; /* In certain edge cases, MLA can produce p_lse = s_lse = -inf;
continuing the pipeline then yields NaN. Root cause: with chunked prefill continuing the pipeline then yields NaN. Root cause: with chunked prefill
a batch may be split into two chunks; if a request in that batch has no a batch may be split into two chunks; if a request in that batch has no
prefix hit, every LSE entry for that request's position is -inf, and at prefix hit, every LSE entry for that requests position is -inf, and at
this moment we merge cross-attention at first. For now we simply emit this moment we merge cross-attention at first. For now we simply emit
prefix_output (expected to be all zeros) and prefix_lse (-inf) to fix prefix_output (expected to be all zeros) and prefix_lse (-inf) to fix
this problem. this problem.
*/ */
if (std::isinf(max_lse)) { if (std::isinf(max_lse)) {
if (pack_offset < head_size) { if (pack_offset < head_size) {
input_pack_t p_out_pack = reinterpret_cast<const input_pack_t*>( // Pack 128b load
pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(
prefix_head_ptr)[pack_offset / pack_size]; prefix_head_ptr)[pack_offset / pack_size];
if constexpr (USE_FP8_OUTPUT) { // Pack 128b storage
// Convert prefix values to FP8 (since -inf means no data, reinterpret_cast<pack_128b_t*>(output_head_ptr)[pack_offset / pack_size] =
// prefix_output is expected to be zeros) p_out_pack;
output_t o_out_pack[pack_size];
#pragma unroll
for (uint i = 0; i < pack_size; ++i) {
const float val =
vllm::to_float(reinterpret_cast<const scalar_t*>(&p_out_pack)[i]);
o_out_pack[i] =
vllm::scaled_fp8_conversion<true, output_t>(val, fp8_scale_inv);
}
reinterpret_cast<output_pack_t*>(
output_head_ptr)[pack_offset / pack_size] =
*reinterpret_cast<output_pack_t*>(o_out_pack);
} else {
reinterpret_cast<output_pack_t*>(
output_head_ptr)[pack_offset / pack_size] = p_out_pack;
}
} }
// We only need to write to output_lse once per head. // We only need to write to output_lse once per head.
if (output_lse != nullptr && pack_idx == 0) { if (output_lse != nullptr && pack_idx == 0) {
@@ -146,43 +84,30 @@ __global__ void merge_attn_states_kernel(
const float s_scale = s_se / out_se; const float s_scale = s_se / out_se;
if (pack_offset < head_size) { if (pack_offset < head_size) {
input_pack_t p_out_pack = reinterpret_cast<const input_pack_t*>( // Pack 128b load
pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(
prefix_head_ptr)[pack_offset / pack_size]; prefix_head_ptr)[pack_offset / pack_size];
input_pack_t s_out_pack = reinterpret_cast<const input_pack_t*>( pack_128b_t s_out_pack = reinterpret_cast<const pack_128b_t*>(
suffix_head_ptr)[pack_offset / pack_size]; suffix_head_ptr)[pack_offset / pack_size];
pack_128b_t o_out_pack;
// Compute merged values in float32
float o_out_f[pack_size];
#pragma unroll #pragma unroll
for (uint i = 0; i < pack_size; ++i) { for (uint i = 0; i < pack_size; ++i) {
// Always use float for FMA to keep high precision.
// half(uint16_t), bfloat16, float -> float.
const float p_out_f = const float p_out_f =
vllm::to_float(reinterpret_cast<const scalar_t*>(&p_out_pack)[i]); vllm::to_float(reinterpret_cast<const scalar_t*>(&p_out_pack)[i]);
const float s_out_f = const float s_out_f =
vllm::to_float(reinterpret_cast<const scalar_t*>(&s_out_pack)[i]); vllm::to_float(reinterpret_cast<const scalar_t*>(&s_out_pack)[i]);
o_out_f[i] = p_out_f * p_scale + (s_out_f * s_scale); // fma: a * b + c = p_out_f * p_scale + (s_out_f * s_scale)
const float o_out_f = p_out_f * p_scale + (s_out_f * s_scale);
// float -> half(uint16_t), bfloat16, float.
vllm::from_float(reinterpret_cast<scalar_t*>(&o_out_pack)[i], o_out_f);
} }
// Convert and store // Pack 128b storage
if constexpr (USE_FP8_OUTPUT) { reinterpret_cast<pack_128b_t*>(output_head_ptr)[pack_offset / pack_size] =
output_t o_out_pack[pack_size]; o_out_pack;
#pragma unroll
for (uint i = 0; i < pack_size; ++i) {
o_out_pack[i] = vllm::scaled_fp8_conversion<true, output_t>(
o_out_f[i], fp8_scale_inv);
}
reinterpret_cast<output_pack_t*>(
output_head_ptr)[pack_offset / pack_size] =
*reinterpret_cast<output_pack_t*>(o_out_pack);
} else {
output_pack_t o_out_pack;
#pragma unroll
for (uint i = 0; i < pack_size; ++i) {
vllm::from_float(reinterpret_cast<scalar_t*>(&o_out_pack)[i],
o_out_f[i]);
}
reinterpret_cast<output_pack_t*>(
output_head_ptr)[pack_offset / pack_size] = o_out_pack;
}
} }
// We only need to write to output_lse once per head. // We only need to write to output_lse once per head.
if (output_lse != nullptr && pack_idx == 0) { if (output_lse != nullptr && pack_idx == 0) {
@@ -209,73 +134,50 @@ __global__ void merge_attn_states_kernel(
} \ } \
} }
#define LAUNCH_MERGE_ATTN_STATES(scalar_t, output_t, NUM_THREADS, \ #define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS) \
USE_FP8_OUTPUT) \
{ \ { \
vllm::merge_attn_states_kernel<scalar_t, output_t, NUM_THREADS, \ vllm::merge_attn_states_kernel<scalar_t, NUM_THREADS> \
USE_FP8_OUTPUT> \
<<<grid, block, 0, stream>>>( \ <<<grid, block, 0, stream>>>( \
reinterpret_cast<output_t*>(output.data_ptr()), output_lse_ptr, \ reinterpret_cast<scalar_t*>(output.data_ptr()), output_lse_ptr, \
reinterpret_cast<scalar_t*>(prefix_output.data_ptr()), \ reinterpret_cast<scalar_t*>(prefix_output.data_ptr()), \
reinterpret_cast<float*>(prefix_lse.data_ptr()), \ reinterpret_cast<float*>(prefix_lse.data_ptr()), \
reinterpret_cast<scalar_t*>(suffix_output.data_ptr()), \ reinterpret_cast<scalar_t*>(suffix_output.data_ptr()), \
reinterpret_cast<float*>(suffix_lse.data_ptr()), num_tokens, \ reinterpret_cast<float*>(suffix_lse.data_ptr()), num_tokens, \
num_heads, head_size, prefix_head_stride, output_head_stride, \ num_heads, head_size, prefix_head_stride, output_head_stride); \
prefix_num_tokens, output_scale_ptr); \
} }
/*@brief Merges the attention states from prefix and suffix /*@brief Merges the attention states from prefix and suffix
* into the output tensor. NUM_TOKENS: n, NUM_HEADS: h, HEAD_SIZE: d * into the output tensor. NUM_TOKENS: n, NUM_HEADS: h, HEAD_SIZE: d
* *
* @param output [n,h,d] The output tensor to store the merged attention states. * @param output [n,h,d] The output tensor to store the merged attention states.
* @param output_lse [h,n] Optional tensor to store the log-sum-exp values. * @param output_lse [h,d] Optional tensor to store the log-sum-exp values.
* @param prefix_output [n,h,d] The prefix attention states. * @param prefix_output [n,h,d] The prefix attention states.
* @param prefix_lse [h,n] The log-sum-exp values for the prefix attention * @param prefix_lse [h,n] The log-sum-exp values for the prefix attention
* states. * states.
* @param suffix_output [n,h,d] The suffix attention states. * @param suffix_output [n,h,d] The suffix attention states.
* @param suffix_lse [h,n] The log-sum-exp values for the suffix attention * @param suffix_lse [h,n] The log-sum-exp values for the suffix attention
* states. * states.
* @param prefill_tokens_with_context Number of prefill tokens with context
* For the first p tokens (0 <= token_idx < prefill_tokens_with_context), output
* is computed by merging prefix_output and suffix_output. For remaining tokens
* (prefill_tokens_with_context <= token_idx < n), output is copied directly
* from suffix_output.
* @param output_scale Optional scalar tensor for FP8 static quantization.
* When provided, output must be FP8 dtype.
*/ */
template <typename scalar_t> template <typename scalar_t>
void merge_attn_states_launcher( void merge_attn_states_launcher(torch::Tensor& output,
torch::Tensor& output, std::optional<torch::Tensor> output_lse, std::optional<torch::Tensor> output_lse,
const torch::Tensor& prefix_output, const torch::Tensor& prefix_lse, const torch::Tensor& prefix_output,
const torch::Tensor& suffix_output, const torch::Tensor& suffix_lse, const torch::Tensor& prefix_lse,
const std::optional<int64_t> prefill_tokens_with_context, const torch::Tensor& suffix_output,
const std::optional<torch::Tensor>& output_scale) { const torch::Tensor& suffix_lse) {
constexpr uint NUM_THREADS = 128; constexpr uint NUM_THREADS = 128;
const uint num_tokens = output.size(0); const uint num_tokens = output.size(0);
const uint num_heads = output.size(1); const uint num_heads = output.size(1);
const uint head_size = output.size(2); const uint head_size = output.size(2);
const uint prefix_head_stride = prefix_output.stride(1); const uint prefix_head_stride = prefix_output.stride(1);
const uint output_head_stride = output.stride(1); const uint output_head_stride = output.stride(1);
// Thread mapping is based on input BF16 pack_size
const uint pack_size = 16 / sizeof(scalar_t); const uint pack_size = 16 / sizeof(scalar_t);
TORCH_CHECK(head_size % pack_size == 0, TORCH_CHECK(head_size % pack_size == 0,
"headsize must be multiple of pack_size:", pack_size); "headsize must be multiple of pack_size:", pack_size);
const uint prefix_num_tokens =
prefill_tokens_with_context.has_value()
? static_cast<uint>(prefill_tokens_with_context.value())
: num_tokens;
TORCH_CHECK(prefix_num_tokens <= num_tokens,
"prefix_num_tokens must be <= num_tokens");
float* output_lse_ptr = nullptr; float* output_lse_ptr = nullptr;
if (output_lse.has_value()) { if (output_lse.has_value()) {
output_lse_ptr = output_lse.value().data_ptr<float>(); output_lse_ptr = output_lse.value().data_ptr<float>();
} }
float* output_scale_ptr = nullptr;
if (output_scale.has_value()) {
output_scale_ptr = output_scale.value().data_ptr<float>();
}
// Process one pack elements per thread. for float, the // Process one pack elements per thread. for float, the
// pack_size is 4 for half/bf16, the pack_size is 8. // pack_size is 4 for half/bf16, the pack_size is 8.
const uint threads_per_head = head_size / pack_size; const uint threads_per_head = head_size / pack_size;
@@ -287,22 +189,14 @@ void merge_attn_states_launcher(
const c10::cuda::OptionalCUDAGuard device_guard(prefix_output.device()); const c10::cuda::OptionalCUDAGuard device_guard(prefix_output.device());
auto stream = at::cuda::getCurrentCUDAStream(); auto stream = at::cuda::getCurrentCUDAStream();
if (output_scale.has_value()) { LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS);
// FP8 output path - dispatch on output FP8 type
VLLM_DISPATCH_FP8_TYPES(output.scalar_type(), "merge_attn_states_fp8", [&] {
LAUNCH_MERGE_ATTN_STATES(scalar_t, fp8_t, NUM_THREADS, true);
});
} else {
// Original BF16/FP16/FP32 output path
LAUNCH_MERGE_ATTN_STATES(scalar_t, scalar_t, NUM_THREADS, false);
}
} }
#define CALL_MERGE_ATTN_STATES_LAUNCHER(scalar_t) \ #define CALL_MERGE_ATTN_STATES_LAUNCHER(scalar_t) \
{ \ { \
merge_attn_states_launcher<scalar_t>( \ merge_attn_states_launcher<scalar_t>(output, output_lse, prefix_output, \
output, output_lse, prefix_output, prefix_lse, suffix_output, \ prefix_lse, suffix_output, \
suffix_lse, prefill_tokens_with_context, output_scale); \ suffix_lse); \
} }
void merge_attn_states(torch::Tensor& output, void merge_attn_states(torch::Tensor& output,
@@ -310,21 +204,6 @@ void merge_attn_states(torch::Tensor& output,
const torch::Tensor& prefix_output, const torch::Tensor& prefix_output,
const torch::Tensor& prefix_lse, const torch::Tensor& prefix_lse,
const torch::Tensor& suffix_output, const torch::Tensor& suffix_output,
const torch::Tensor& suffix_lse, const torch::Tensor& suffix_lse) {
std::optional<int64_t> prefill_tokens_with_context, DISPATCH_BY_SCALAR_DTYPE(output.dtype(), CALL_MERGE_ATTN_STATES_LAUNCHER);
const std::optional<torch::Tensor>& output_scale) {
if (output_scale.has_value()) {
TORCH_CHECK(output.scalar_type() == at::ScalarType::Float8_e4m3fn ||
output.scalar_type() == at::ScalarType::Float8_e4m3fnuz,
"output must be FP8 when output_scale is provided, got: ",
output.scalar_type());
} else {
TORCH_CHECK(output.scalar_type() == prefix_output.scalar_type(),
"output dtype (", output.scalar_type(),
") must match prefix_output dtype (",
prefix_output.scalar_type(), ") when output_scale is not set");
}
// Always dispatch on prefix_output (input) dtype
DISPATCH_BY_SCALAR_DTYPE(prefix_output.dtype(),
CALL_MERGE_ATTN_STATES_LAUNCHER);
} }

View File

@@ -10,10 +10,6 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
int64_t block_size_in_bytes,
const torch::Tensor& block_mapping);
void swap_blocks_batch(const torch::Tensor& src_ptrs,
const torch::Tensor& dst_ptrs,
const torch::Tensor& sizes);
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& slot_mapping,

View File

@@ -24,8 +24,6 @@
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
typedef __hip_bfloat16 __nv_bfloat16;
#else
#include <cuda.h>
#endif
#if defined(__gfx942__)
@@ -75,59 +73,6 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
}
}
void swap_blocks_batch(const torch::Tensor& src_ptrs,
const torch::Tensor& dst_ptrs,
const torch::Tensor& sizes) {
TORCH_CHECK(src_ptrs.device().is_cpu(), "src_ptrs must be on CPU");
TORCH_CHECK(dst_ptrs.device().is_cpu(), "dst_ptrs must be on CPU");
TORCH_CHECK(sizes.device().is_cpu(), "sizes must be on CPU");
TORCH_CHECK(src_ptrs.dtype() == torch::kInt64, "src_ptrs must be int64");
TORCH_CHECK(dst_ptrs.dtype() == torch::kInt64, "dst_ptrs must be int64");
TORCH_CHECK(sizes.dtype() == torch::kInt64, "sizes must be int64");
const int64_t n = src_ptrs.size(0);
TORCH_CHECK(dst_ptrs.size(0) == n, "dst_ptrs length must match src_ptrs");
TORCH_CHECK(sizes.size(0) == n, "sizes length must match src_ptrs");
if (n == 0) return;
const int64_t* src_data = src_ptrs.data_ptr<int64_t>();
const int64_t* dst_data = dst_ptrs.data_ptr<int64_t>();
const int64_t* size_data = sizes.data_ptr<int64_t>();
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// Use cuMemcpyBatchAsync (CUDA 12.8+) to submit all copies in a single
// driver call, amortizing per-copy submission overhead.
// int64_t and CUdeviceptr/size_t are both 8 bytes on 64-bit platforms,
// so we reinterpret_cast the tensor data directly to avoid copies.
static_assert(sizeof(CUdeviceptr) == sizeof(int64_t));
static_assert(sizeof(size_t) == sizeof(int64_t));
#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12080
CUmemcpyAttributes attr = {};
attr.srcAccessOrder = CU_MEMCPY_SRC_ACCESS_ORDER_STREAM;
size_t attrs_idx = 0;
size_t fail_idx = 0;
CUresult result = cuMemcpyBatchAsync(
reinterpret_cast<CUdeviceptr*>(const_cast<int64_t*>(dst_data)),
reinterpret_cast<CUdeviceptr*>(const_cast<int64_t*>(src_data)),
reinterpret_cast<size_t*>(const_cast<int64_t*>(size_data)),
static_cast<size_t>(n), &attr, &attrs_idx, 1, &fail_idx,
static_cast<CUstream>(stream));
TORCH_CHECK(result == CUDA_SUCCESS, "cuMemcpyBatchAsync failed at index ",
fail_idx, " with error ", result);
#else
// Fallback for CUDA < 12.8 and ROCm: individual async copies.
// cudaMemcpyDefault lets the driver infer direction from pointer types.
for (int64_t i = 0; i < n; i++) {
cudaMemcpyAsync(reinterpret_cast<void*>(dst_data[i]),
reinterpret_cast<void*>(src_data[i]),
static_cast<size_t>(size_data[i]), cudaMemcpyDefault,
stream);
}
#endif
}
namespace vllm {
// Grid: (num_layers, num_pairs)

View File

@@ -30,15 +30,13 @@
}()
namespace {
enum class FusedMOEAct { SiluAndMul, SwigluOAIAndMul, GeluAndMul };
enum class FusedMOEAct { SiluAndMul, SwigluOAIAndMul };
FusedMOEAct get_act_type(const std::string& act) {
if (act == "silu") {
return FusedMOEAct::SiluAndMul;
} else if (act == "swigluoai") {
return FusedMOEAct::SwigluOAIAndMul;
} else if (act == "gelu") {
return FusedMOEAct::GeluAndMul;
} else {
TORCH_CHECK(false, "Invalid act type: " + act);
}
@@ -106,43 +104,6 @@ void silu_and_mul(float* __restrict__ input, scalar_t* __restrict__ output,
}
}
template <typename scalar_t>
void gelu_and_mul(float* __restrict__ input, scalar_t* __restrict__ output,
const int32_t m_size, const int32_t n_size,
const int32_t input_stride, const int32_t output_stride) {
using scalar_vec_t = typename cpu_utils::VecTypeTrait<scalar_t>::vec_t;
const int32_t dim = n_size / 2;
float* __restrict__ gate = input;
float* __restrict__ up = input + dim;
vec_op::FP32Vec16 one_vec(1.0);
vec_op::FP32Vec16 w1_vec(M_SQRT1_2);
vec_op::FP32Vec16 w2_vec(0.5);
alignas(64) float temp[16];
DEFINE_FAST_EXP
for (int32_t m = 0; m < m_size; ++m) {
for (int32_t n = 0; n < dim; n += 16) {
vec_op::FP32Vec16 gate_vec(gate + n);
vec_op::FP32Vec16 up_vec(up + n);
auto er_input_vec = gate_vec * w1_vec;
er_input_vec.save(temp);
for (int32_t i = 0; i < 16; ++i) {
temp[i] = std::erf(temp[i]);
}
vec_op::FP32Vec16 er_vec(temp);
auto gelu = gate_vec * w2_vec * (one_vec + er_vec);
auto gated_output_fp32 = up_vec * gelu;
scalar_vec_t gated_output = scalar_vec_t(gated_output_fp32);
gated_output.save(output + n);
}
gate += input_stride;
up += input_stride;
output += output_stride;
}
}
template <typename scalar_t>
FORCE_INLINE void apply_gated_act(const FusedMOEAct act,
float* __restrict__ input,
@@ -157,9 +118,6 @@ FORCE_INLINE void apply_gated_act(const FusedMOEAct act,
case FusedMOEAct::SiluAndMul:
silu_and_mul(input, output, m, n, input_stride, output_stride);
return;
case FusedMOEAct::GeluAndMul:
gelu_and_mul(input, output, m, n, input_stride, output_stride);
return;
default:
TORCH_CHECK(false, "Unsupported act type.");
}

View File

@@ -8,7 +8,7 @@ Generate CPU attention dispatch switch cases and kernel instantiations.
import os
# Head dimensions divisible by 32 (support all ISAs)
HEAD_DIMS_32 = [32, 64, 96, 128, 160, 192, 224, 256, 512]
HEAD_DIMS_32 = [32, 64, 96, 128, 160, 192, 224, 256]
# Head dimensions divisible by 16 but not 32 (VEC16 only)
HEAD_DIMS_16 = [80, 112]

View File

@@ -117,14 +117,6 @@ inline void parallel_for(int n, const func_t& f) {
#endif
}
inline int get_thread_num() {
#if defined(_OPENMP)
return omp_get_thread_num();
#else
return 0;
#endif
}
// for 1d parallel, use `actual_nth`
// for 2d parallel, use even nths, e.g. 43->42
int inline adjust_num_threads(int m) {

View File

@@ -17,8 +17,8 @@ constexpr int block_size_n() { return 2 * TILE_N; }
template <typename T> inline bool can_use_brgemm(int M);
template <> inline bool can_use_brgemm<at::BFloat16>(int M) { return M > 4; }
template <> inline bool can_use_brgemm<at::Half>(int M) { return true; }
template <> inline bool can_use_brgemm<int8_t>(int M) { return M > 4; }
template <> inline bool can_use_brgemm<uint8_t>(int M) { return M > 4; }
// TODO: add u8s8 brgemm, this requires PyTorch 2.7
template <> inline bool can_use_brgemm<int8_t>(int M) { return false; }
template <> inline bool can_use_brgemm<at::Float8_e4m3fn>(int M) { return M > 4; }
template <> inline bool can_use_brgemm<at::quint4x2>(int M) { return M > 4; }
@@ -40,17 +40,9 @@ inline int64_t get_row_size(int64_t K, bool use_int8_w8a8) {
return use_int8_w8a8 ? K + sizeof(int32_t) : K;
}
inline int64_t get_4bit_block_k_size(int64_t group_size) {
return group_size > 128 ? 128 : group_size;
}
// pack weight into vnni format
// pack weight to vnni format
at::Tensor convert_weight_packed(at::Tensor& weight);
// pack weight to vnni format for int4 (adapted from sglang)
std::tuple<at::Tensor, at::Tensor, at::Tensor>
convert_weight_packed_scale_zp(at::Tensor qweight, at::Tensor qzeros, at::Tensor scales);
// moe implementations for int8 w8a8
template <typename scalar_t>
void fused_experts_int8_kernel_impl(
@@ -241,31 +233,6 @@ void tinygemm_kernel(
int64_t strideBs,
bool brg);
// int4 scaled GEMM (adapted from sglang)
at::Tensor int4_scaled_mm_cpu(
at::Tensor& x, at::Tensor& w, at::Tensor& w_zeros, at::Tensor& w_scales, std::optional<at::Tensor> bias);
// int4 tinygemm kernel interface(adapted from sglang)
template <typename scalar_t>
void tinygemm_kernel(
scalar_t* C,
float* C_temp,
const uint8_t* A,
const float* scales_a,
const int32_t* qzeros_a,
const uint8_t* B,
const float* scales_b,
const int8_t* qzeros_b,
const int32_t* compensation,
int8_t* dqB_tmp,
int64_t M,
int64_t K,
int64_t lda,
int64_t ldc_f,
int64_t ldc_s,
bool store_out,
bool use_brgemm);
// TODO: debug print, remove me later
inline void print_16x32i(const __m512i x) {
int32_t a[16];

View File

@@ -1,755 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// Adapted from sgl-project/sglang
// https://github.com/sgl-project/sglang/pull/8226
#include <ATen/ATen.h>
#include "common.h"
#include "gemm.h"
#include "vec.h"
namespace {
#define BLOCK_N block_size_n()
#define BLOCK_M 128
template <bool sym_quant_act>
struct ActDtype;
template <>
struct ActDtype<true> {
using type = int8_t;
};
template <>
struct ActDtype<false> {
using type = uint8_t;
};
struct alignas(32) m256i_wrapper {
__m256i data;
};
#if defined(CPU_CAPABILITY_AVX512)
inline std::array<m256i_wrapper, 2> load_zps_4vnni(
const int8_t* __restrict__ zps) {
__m256i vzps_low = _mm256_set1_epi64x(*reinterpret_cast<const int64_t*>(zps));
__m256i vzps_high =
_mm256_set1_epi64x(*reinterpret_cast<const int64_t*>(zps + 8));
__m256i shuffle_mask =
_mm256_set_epi8(7, 7, 7, 7, 6, 6, 6, 6, 5, 5, 5, 5, 4, 4, 4, 4, 3, 3, 3,
3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0);
vzps_low = _mm256_shuffle_epi8(vzps_low, shuffle_mask);
vzps_high = _mm256_shuffle_epi8(vzps_high, shuffle_mask);
m256i_wrapper vzps_low_wp, vzps_high_wp;
vzps_low_wp.data = vzps_low;
vzps_high_wp.data = vzps_high;
return {vzps_low_wp, vzps_high_wp};
}
inline std::array<m256i_wrapper, 2> load_uint4_as_int8(
const uint8_t* __restrict__ qB) {
__m256i packed = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(qB));
const __m256i low_mask = _mm256_set1_epi8(0x0f);
__m256i high = _mm256_srli_epi16(packed, 4);
high = _mm256_and_si256(high, low_mask);
__m256i low = _mm256_and_si256(packed, low_mask);
m256i_wrapper low_wp, high_wp;
low_wp.data = low;
high_wp.data = high;
return {low_wp, high_wp};
}
template <int N, int ldb>
void _dequant_weight_zp_only(const uint8_t* __restrict__ B, int8_t* dqB,
const int8_t* __restrict__ qzeros, int64_t K) {
#pragma GCC unroll 2
for (int n = 0; n < N; n += 16) {
auto [zps_low_wp, zps_high_wp] = load_zps_4vnni(&qzeros[n]);
auto zps_low = zps_low_wp.data;
auto zps_high = zps_high_wp.data;
for (int k = 0; k < K; k += 4) {
auto [vb_low_wp, vb_high_wp] =
load_uint4_as_int8(B + ldb * k + n / 2 * 4);
auto vb_low = vb_low_wp.data;
auto vb_high = vb_high_wp.data;
vb_high = _mm256_sub_epi8(vb_high, zps_high);
vb_low = _mm256_sub_epi8(vb_low, zps_low);
_mm256_storeu_si256(reinterpret_cast<__m256i_u*>(dqB + N * k + n * 4),
vb_low);
_mm256_storeu_si256(
reinterpret_cast<__m256i_u*>(dqB + N * k + (n + 8) * 4), vb_high);
}
}
}
template <bool sym_quant_act, int N, bool accum>
void _dequant_and_store(float* __restrict__ output,
const int32_t* __restrict__ input,
const float* __restrict__ scale_a,
const int32_t* __restrict__ zp_a,
const float* __restrict__ scale_b,
const int32_t* __restrict__ comp_b, int M, int ldi,
int ldo, int ldsa = 1) {
for (int m = 0; m < M; ++m) {
float a_scale = *(scale_a + m * ldsa);
__m512 va_scale = _mm512_set1_ps(a_scale);
int32_t a_zp;
__m512i va_zp;
if constexpr (!sym_quant_act) {
a_zp = *(zp_a + m * ldsa);
va_zp = _mm512_set1_epi32(a_zp);
}
int n = 0;
#pragma GCC unroll 2
for (; n < N; n += 16) {
__m512i vc = _mm512_loadu_si512(input + m * ldi + n);
if constexpr (!sym_quant_act) {
__m512i vb_comp = _mm512_loadu_si512(comp_b + n);
vc = _mm512_sub_epi32(vc, _mm512_mullo_epi32(vb_comp, va_zp));
}
__m512 vc_f = _mm512_cvtepi32_ps(vc);
__m512 vc_f_mul = _mm512_mul_ps(vc_f, va_scale);
__m512 vb_s = _mm512_loadu_ps(scale_b + n);
vc_f_mul = _mm512_mul_ps(vc_f_mul, vb_s);
if constexpr (accum) {
__m512 vo = _mm512_loadu_ps(output + m * ldo + n);
_mm512_storeu_ps(output + m * ldo + n, _mm512_add_ps(vo, vc_f_mul));
} else {
_mm512_storeu_ps(output + m * ldo + n, vc_f_mul);
}
}
for (; n < N; ++n) {
float dq_val;
if constexpr (sym_quant_act) {
dq_val = (float)input[m * ldi + n] * a_scale * scale_b[n];
} else {
dq_val = (float)(input[m * ldi + n] - a_zp * comp_b[n]) * a_scale *
scale_b[n];
}
if constexpr (accum) {
output[m * ldo + n] += dq_val;
} else {
output[m * ldo + n] = dq_val;
}
}
}
}
#else
template <int N, int ldb>
void _dequant_weight_zp_only(const uint8_t* B, int8_t* dqB,
const int8_t* qzeros, int64_t K) {
for (int k = 0; k < K; ++k) {
for (int n = 0; n < N / 2; ++n) {
int32_t b = (int32_t)B[k * ldb + n];
dqB[k * N + n * 2] = (b & 0xf) - qzeros[n];
dqB[k * N + n * 2 + 1] = (b >> 4) - qzeros[n];
}
}
}
#endif
#if defined(CPU_CAPABILITY_AVX512)
inline __m512i combine_m256i(__m256i a, __m256i b) {
__m512i c = _mm512_castsi256_si512(a);
return _mm512_inserti64x4(c, b, 1);
}
inline __m512i combine_m256i(std::array<m256i_wrapper, 2> two_256) {
return combine_m256i(two_256[0].data, two_256[1].data);
}
static inline __m512i _mm512_sign_epi8(__m512i a, __m512i b) {
__m512i zero = _mm512_setzero_si512();
__mmask64 blt0 = _mm512_movepi8_mask(b);
return _mm512_mask_sub_epi8(a, blt0, zero, a);
}
template <bool sym_quant_act, int M, int N, int ldb>
void _dequant_gemm_accum_small_M(float* __restrict__ C, const uint8_t* A,
const float* scales_a, const int32_t* qzeros_a,
const uint8_t* B, const float* scales_b,
const int8_t* qzeros_b, int64_t K, int64_t lda,
int64_t ldc) {
constexpr int COLS = N / 16;
__m512i ones = _mm512_set1_epi8(1);
__m512i va;
__m512i vb[COLS];
__m512i vc[M * COLS];
__m512 vscales[COLS];
__m512i vzps[COLS];
__m512i vcompensate[COLS];
Unroll<COLS>{}([&](auto i) {
vscales[i] = _mm512_loadu_ps(scales_b + i * 16);
vzps[i] = combine_m256i(load_zps_4vnni(qzeros_b + i * 16));
if constexpr (!sym_quant_act) {
vcompensate[i] = _mm512_setzero_epi32();
}
});
Unroll<M * COLS>{}([&](auto i) { vc[i] = _mm512_setzero_epi32(); });
auto compute = [&](auto i, int k) {
constexpr const int row = i / COLS;
constexpr const int col = i % COLS;
if constexpr (col == 0) {
va = _mm512_set1_epi32(*(int32_t*)(A + row * lda + k));
}
if constexpr (row == 0) {
int B_offset = k * ldb + col * 16 * 2;
vb[col] = combine_m256i(load_uint4_as_int8(B + B_offset));
vb[col] = _mm512_sub_epi8(vb[col], vzps[col]);
if constexpr (!sym_quant_act) {
vcompensate[col] = _mm512_dpbusd_epi32(vcompensate[col], ones, vb[col]);
}
_mm_prefetch(B + B_offset + 128 * ldb, _MM_HINT_T0);
}
if constexpr (sym_quant_act) {
auto vsb = _mm512_sign_epi8(vb[col], va);
auto vabsa = _mm512_sign_epi8(va, va);
vc[i] = _mm512_dpbusds_epi32(vc[i], vabsa, vsb);
} else {
vc[i] = _mm512_dpbusd_epi32(vc[i], va, vb[col]);
}
};
constexpr const int unroll = 4;
int k = 0;
for (; k < K / 4 / unroll; k++) {
Unroll<unroll>{}(
[&](auto i) { Unroll<M * COLS>{}(compute, 4 * (k * unroll + i)); });
}
k *= 4 * unroll;
for (; k < K; k += 4) {
Unroll<M * COLS>{}(compute, k);
}
auto store = [&](auto i) {
constexpr const int row = i / COLS;
constexpr const int col = i % COLS;
__m512 vc_float;
if constexpr (!sym_quant_act) {
vc[i] = _mm512_sub_epi32(
vc[i], _mm512_mullo_epi32(vcompensate[col],
_mm512_set1_epi32(*(qzeros_a + row))));
}
vc_float = _mm512_cvtepi32_ps(vc[i]);
vc_float = _mm512_mul_ps(vc_float, _mm512_set1_ps(*(scales_a + row)));
vc_float = _mm512_mul_ps(vc_float, vscales[col]);
auto vc_old = _mm512_loadu_ps(C + row * ldc + col * 16);
vc_float = _mm512_add_ps(vc_float, vc_old);
_mm512_storeu_ps(C + row * ldc + col * 16, vc_float);
};
Unroll<M * COLS>{}(store);
}
#define CALL_DEQUANT_GEMM_ACCUM_SMALL_M(M) \
_dequant_gemm_accum_small_M<sym_quant_act, M, N, ldb>( \
C, A, scales_a, qzeros_a, B, scales_b, qzeros_b, K, lda, ldc);
#endif
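// Computes one (M x BLOCK_N) tile over a block_k slice: small M uses the
// AVX-512 micro-kernel above; otherwise the int4 weight block is dequantized
// to int8 and multiplied via brgemm, and the int32 result is dequantized and
// accumulated into the fp32 C buffer.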
template <bool sym_quant_act, int N, int ldb>
void _dequant_gemm_accum(float* C, const uint8_t* A, const float* scales_a,
const int32_t* qzeros_a, const uint8_t* B,
const float* scales_b, const int8_t* qzeros_b,
const int32_t* compensation, int8_t* dqB, int64_t M,
int64_t K, int64_t lda, int64_t ldc, bool use_brgemm) {
#if defined(CPU_CAPABILITY_AVX512)
if (!use_brgemm) {
switch (M) {
case 1:
CALL_DEQUANT_GEMM_ACCUM_SMALL_M(1);
break;
case 2:
CALL_DEQUANT_GEMM_ACCUM_SMALL_M(2);
break;
case 3:
CALL_DEQUANT_GEMM_ACCUM_SMALL_M(3);
break;
case 4:
CALL_DEQUANT_GEMM_ACCUM_SMALL_M(4);
break;
default:
TORCH_CHECK(false, "tinygemm_kernel: unexpected M for AVX path!");
}
return;
}
_dequant_weight_zp_only<N, ldb>(B, dqB, qzeros_b, K);
using Tin = typename ActDtype<sym_quant_act>::type;
Tin* A_ptr = (Tin*)A;
if (use_brgemm) {
int32_t C_i32[M * N];
at::native::cpublas::brgemm(M, N, K, lda, N /*ldb*/, N /*ldc*/,
false /* add_C */, A_ptr, dqB, C_i32,
true /* is_vnni */);
_mm_prefetch(B + N * K / 2, _MM_HINT_T0);
_mm_prefetch(A + K, _MM_HINT_T0);
_dequant_and_store<sym_quant_act, N, true>(C, C_i32, scales_a, qzeros_a,
scales_b, compensation, M,
N /*ldi*/, ldc, 1 /*ldsa*/);
} else
#endif
{
TORCH_CHECK(false, "tinygemm_kernel: scalar path not implemented!");
}
}
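// Initialize the fp32 accumulation buffer with the bias (broadcast across
// rows) or with zeros when no bias is given.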
template <int N>
inline void copy_bias(const float* bias_ptr, float* y_buf, int64_t m) {
if (bias_ptr) {
for (int i = 0; i < m; ++i) {
int j = 0;
#if defined(CPU_CAPABILITY_AVX512)
#pragma GCC unroll 2
for (; j < N; j += 16) {
__m512 bias_vec = _mm512_loadu_ps(bias_ptr + j);
_mm512_storeu_ps(y_buf + i * N + j, bias_vec);
}
#endif
for (; j < N; ++j) {
y_buf[i * N + j] = bias_ptr[j];
}
}
} else {
for (int i = 0; i < m; ++i) {
int j = 0;
#if defined(CPU_CAPABILITY_AVX512)
#pragma GCC unroll 2
for (; j < N; j += 16) {
__m512 zero_vec = _mm512_setzero_ps();
_mm512_storeu_ps(y_buf + i * N + j, zero_vec);
}
#endif
for (; j < N; ++j) {
y_buf[i * N + j] = 0;
}
}
}
}
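// Copy the fp32 accumulation buffer to the output, converting to fp32, bf16,
// or fp16 as requested.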
template <int N, typename out_dtype>
inline void store_out(const float* y_buf, out_dtype* c_ptr, int64_t m,
int64_t lda) {
for (int i = 0; i < m; ++i) {
int j = 0;
if constexpr (std::is_same<out_dtype, float>::value) {
#if defined(CPU_CAPABILITY_AVX512)
#pragma GCC unroll 2
for (; j < N; j += 16) {
__m512 y_vec = _mm512_loadu_ps(y_buf + i * N + j);
_mm512_storeu_ps(c_ptr + i * lda + j, y_vec);
}
#endif
for (; j < N; ++j) {
c_ptr[i * lda + j] = y_buf[i * N + j];
}
} else if constexpr (std::is_same<out_dtype, at::BFloat16>::value) {
#if defined(CPU_CAPABILITY_AVX512)
#pragma GCC unroll 2
for (; j < N; j += 16) {
__m512 y_vec = _mm512_loadu_ps(y_buf + i * N + j);
__m256i y_bf16_vec = at::vec::cvtfp32_bf16(y_vec);
_mm256_storeu_si256(reinterpret_cast<__m256i*>(c_ptr + i * lda + j),
y_bf16_vec);
}
#endif
for (; j < N; ++j) {
c_ptr[i * lda + j] = at::BFloat16(y_buf[i * N + j]);
}
} else if constexpr (std::is_same<out_dtype, at::Half>::value) {
#if defined(CPU_CAPABILITY_AVX512)
#pragma GCC unroll 2
for (; j < N; j += 16) {
__m512 y_vec = _mm512_loadu_ps(y_buf + i * N + j);
__m256i y_fp16_vec = at::vec::cvtfp32_fp16(y_vec);
_mm256_storeu_si256(reinterpret_cast<__m256i*>(c_ptr + i * lda + j),
y_fp16_vec);
}
#endif
for (; j < N; ++j) {
c_ptr[i * lda + j] = at::Half(y_buf[i * N + j]);
}
} else {
TORCH_CHECK(false, "Unsupported output dtype");
}
}
}
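// Vectorized fill of an int32 buffer with a constant; used below to set all
// activation zero-points to 128 for asymmetric uint8 quantization.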
void fill_val_stub(int32_t* __restrict__ output, int32_t value, int64_t size) {
using iVec = at::vec::Vectorized<int32_t>;
constexpr int VecSize = iVec::size();
const iVec fill_val_vec = iVec(value);
int64_t d;
#pragma GCC unroll 4
for (d = 0; d <= size - VecSize; d += VecSize) {
fill_val_vec.store(output + d);
}
for (; d < size; ++d) {
output[d] = value;
}
}
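// Blocked DA8W4 linear kernel: accumulates per-thread fp32 tiles over the K
// blocks for each (M, N) block. Parallelism is over N blocks, or over (M, N)
// blocks when M > 128; each thread uses a private fp32 scratch tile and a
// private dequantized-weight buffer.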
template <bool sym_quant_act, typename act_dtype, typename out_dtype>
void _da8w4_linear_impl(
act_dtype* __restrict__ input, const float* __restrict__ input_scales,
const int32_t* __restrict__ input_qzeros,
const uint8_t* __restrict__ weight, const float* __restrict__ weight_scales,
const int8_t* __restrict__ weight_qzeros, const float* __restrict__ bias,
out_dtype* __restrict__ output, float* __restrict__ output_temp,
int8_t* __restrict__ dequant_weight_temp, int64_t M, int64_t N, int64_t K,
int64_t num_groups) {
const bool use_brgemm = can_use_brgemm<act_dtype>(M);
int64_t block_m = [&]() -> long {
if (M <= 48) {
return M;
} else if (M < 64) {
return 32;
} else if (M < 96) {
return 64;
} else {
return 128;
}
}();
int64_t Mc = div_up(M, block_m);
bool parallel_on_M = M > 128;
int64_t Nc = N / BLOCK_N;
int64_t num_blocks = parallel_on_M ? Mc * Nc : Nc;
int64_t group_size = div_up(K, num_groups);
int64_t _block_k = get_4bit_block_k_size(group_size);
int64_t Kc = K / _block_k;
int64_t block_per_group = group_size / _block_k;
at::parallel_for(0, num_blocks, 1, [&](int64_t begin, int64_t end) {
int tid = get_thread_num();
float* C_tmp = output_temp + tid * block_m * BLOCK_N;
int8_t* dqB_tmp = dequant_weight_temp + tid * _block_k * BLOCK_N;
for (const auto i : c10::irange(begin, end)) {
int64_t mc = parallel_on_M ? i / Nc : 0;
int64_t nc = parallel_on_M ? i % Nc : i;
int64_t mc_end = parallel_on_M ? mc + 1 : Mc;
for (int mci = mc; mci < mc_end; ++mci) {
int64_t m_size =
mci * block_m + block_m > M ? M - mci * block_m : block_m;
auto bias_data = bias ? bias + nc * BLOCK_N : nullptr;
copy_bias<BLOCK_N>(bias_data, C_tmp, m_size);
for (int kci = 0; kci < Kc; ++kci) {
int32_t* compensation_ptr =
sym_quant_act
? nullptr
: (int32_t*)(void*)(weight +
(nc * Kc + kci) *
(BLOCK_N *
(_block_k / 2 + sizeof(int32_t))) +
_block_k * BLOCK_N / 2);
_dequant_gemm_accum<sym_quant_act, BLOCK_N, BLOCK_N / 2>(
/*C*/ C_tmp,
/*A*/ (uint8_t*)input + mci * block_m * K + kci * _block_k,
/*scales_a*/ input_scales + mci * block_m,
/*qzeros_a*/ input_qzeros + mci * block_m,
/*B*/ weight + (nc * Kc + kci) *
(BLOCK_N * (_block_k / 2 + sizeof(int32_t))),
/*scales_b*/ weight_scales + nc * BLOCK_N * num_groups +
kci / block_per_group * BLOCK_N,
/*qzeros_b*/ weight_qzeros + nc * BLOCK_N * num_groups +
kci / block_per_group * BLOCK_N,
/*Bcomp*/ compensation_ptr,
/*dqB_tmp*/ dqB_tmp,
/*M*/ m_size,
/*K*/ _block_k,
/*lda*/ K,
/*ldc*/ BLOCK_N,
/*use_brgemm*/ use_brgemm);
}
store_out<BLOCK_N>(C_tmp, output + mci * block_m * N + nc * BLOCK_N,
m_size, N /*lda*/);
}
}
if (use_brgemm) {
at::native::cpublas::brgemm_release();
}
});
}
} // anonymous namespace
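// Repack int4 weights into the blocked layout consumed by the kernels above:
// each (BLOCK_N x block_k) block stores the int4 values in VNNI order (two
// output channels packed per byte) followed by BLOCK_N int32 compensation
// sums, i.e. sum_k(w - zp) per output channel, used for asymmetric activations.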
std::tuple<at::Tensor, at::Tensor, at::Tensor>
convert_int4_weight_packed_with_compensation(const at::Tensor& weight,
const at::Tensor& scales,
const at::Tensor& qzeros) {
TORCH_CHECK(weight.dim() == 2,
"DA8W4 CPU: Weight should be a 2D tensor for packing");
TORCH_CHECK(
weight.size(1) % 2 == 0,
"DA8W4 CPU: Weight should have even number of columns for packing");
auto new_scales = scales;
auto new_qzeros = qzeros;
if (new_scales.dim() == 1) {
new_scales.unsqueeze_(1);
}
new_scales = new_scales.to(at::kFloat);
if (new_qzeros.dim() == 1) {
new_qzeros.unsqueeze_(1);
}
new_qzeros = new_qzeros.to(at::kChar);
int64_t N = weight.size(0);
int64_t K = weight.size(1);
int64_t G = scales.size(1);
int64_t group_size = K / G;
int64_t _block_k = get_4bit_block_k_size(group_size);
constexpr int block_n = block_size_n();
int64_t Nc = N / block_n;
int64_t Kc = K / _block_k;
auto weight_view = weight.view({Nc, block_n, Kc, _block_k});
at::Tensor weight_reordered = weight_view.permute({0, 2, 3, 1}).contiguous();
at::Tensor blocked_weight;
at::Tensor blocked_scales =
new_scales.view({Nc, block_n, G}).permute({0, 2, 1}).contiguous();
at::Tensor blocked_qzeros =
new_qzeros.view({Nc, block_n, G}).permute({0, 2, 1}).contiguous();
auto weight_sub_qzero = weight.view({Nc, block_n, G, -1}).to(at::kInt) -
new_qzeros.view({Nc, block_n, G, -1});
weight_sub_qzero = weight_sub_qzero.view({Nc, block_n, Kc, _block_k});
at::Tensor compensation = weight_sub_qzero.sum(-1);
compensation = compensation.permute({0, 2, 1}).contiguous().to(at::kInt);
int64_t buffer_size_nbytes =
_block_k * block_n / 2 + block_n * sizeof(int32_t);
blocked_weight = at::empty({Nc, Kc, buffer_size_nbytes}, weight.options());
auto weight_ptr = weight_reordered.data_ptr<uint8_t>();
auto compensation_ptr = compensation.data_ptr<int32_t>();
auto blocked_weight_ptr = blocked_weight.data_ptr<uint8_t>();
int64_t num_blocks = Nc * Kc;
at::parallel_for(0, num_blocks, 1, [&](int64_t begin, int64_t end) {
for (const auto i : c10::irange(begin, end)) {
auto in_ptr = weight_ptr + i * _block_k * block_n;
auto out_ptr =
blocked_weight_ptr + i * block_n * (_block_k / 2 + sizeof(int32_t));
int32_t* comp_in_ptr = compensation_ptr + i * block_n;
int32_t* comp_out_ptr =
(int32_t*)(void*)(blocked_weight_ptr +
i * block_n * (_block_k / 2 + sizeof(int32_t)) +
_block_k * block_n / 2);
constexpr int n_group_size = 8;
constexpr int vnni_size = 4;
constexpr int n_group = block_n / n_group_size;
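// Pack two groups of 8 output channels per byte (low nibble = group nb, high
// nibble = group nb + 1), storing 4 consecutive k values per channel pair in
// VNNI order.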
for (int nb = 0; nb < n_group; nb += 2) {
for (int k = 0; k < _block_k; k += vnni_size) {
for (int ni = 0; ni < n_group_size; ++ni) {
for (int ki = 0; ki < vnni_size; ++ki) {
int src_idx_1 = nb * n_group_size + ni + (k + ki) * block_n;
int src_idx_2 = (nb + 1) * n_group_size + ni + (k + ki) * block_n;
int dst_idx = (nb / 2 * n_group_size + ni) * vnni_size +
k * block_n / 2 + ki;
uint8_t src_1 = *(in_ptr + src_idx_1);
uint8_t src_2 = *(in_ptr + src_idx_2);
uint8_t dst = (src_1 & 0x0f) | ((src_2 & 0x0f) << 4);
*(out_ptr + dst_idx) = dst;
}
}
}
}
for (int nb = 0; nb < block_n; nb++) {
*(comp_out_ptr + nb) = *(comp_in_ptr + nb);
}
}
});
return std::make_tuple(std::move(blocked_weight), std::move(blocked_scales),
std::move(blocked_qzeros));
}
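// Unpack AutoAWQ-packed weights and zero-points (each int32 word holds 8 int4
// values in nibble order 0,4,1,5,2,6,3,7) into one uint8 per element, and
// transpose the last two dims of the unpacked weight.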
std::tuple<at::Tensor, at::Tensor> autoawq_to_int4pack(at::Tensor qweight,
at::Tensor qzeros) {
auto bitshifts = at::tensor({0, 4, 1, 5, 2, 6, 3, 7}, at::kInt) * 4;
auto qweight_unsq = qweight.unsqueeze(-1);
auto unpacked = at::bitwise_right_shift(qweight_unsq, bitshifts) & 0xF;
auto qweight_final = unpacked.flatten(-2).transpose(-1, -2).to(at::kByte);
auto qzeros_unsq = qzeros.unsqueeze(-1);
auto qzeros_unpacked = at::bitwise_right_shift(qzeros_unsq, bitshifts) & 0xF;
auto qzeros_final = qzeros_unpacked.flatten(-2).to(at::kByte);
return std::make_tuple(qweight_final, qzeros_final);
}
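// Prepack entry point: converts AutoAWQ-format qweight/qzeros/scales into the
// blocked DA8W4 layout. 3D inputs (one leading expert dim) are repacked
// expert by expert.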
std::tuple<at::Tensor, at::Tensor, at::Tensor> convert_weight_packed_scale_zp(
at::Tensor qweight, at::Tensor qzeros, at::Tensor scales) {
auto res = autoawq_to_int4pack(qweight, qzeros);
auto _qweight = std::get<0>(res);
auto _qzeros = std::get<1>(res);
auto _scales = scales;
_qzeros = _qzeros.transpose(-2, -1).contiguous();
_scales = _scales.transpose(-2, -1).contiguous();
if (_qweight.dim() == 3) {
int64_t E = _qweight.size(0);
int64_t K = _qweight.size(2);
int64_t G = _scales.size(2);
int64_t group_size = K / G;
int64_t _block_k = get_4bit_block_k_size(group_size);
int64_t block_n = block_size_n();
int64_t Nc = _qweight.size(1) / block_n;
int64_t Kc = K / _block_k;
int64_t buffer_size_nbytes =
_block_k * block_n / 2 + block_n * sizeof(int32_t);
auto blocked_weight =
at::empty({E, Nc, Kc, buffer_size_nbytes}, _qweight.options());
auto blocked_scales =
at::empty({E, Nc, G, block_n}, _scales.options()).to(at::kFloat);
auto blocked_qzeros =
at::empty({E, Nc, G, block_n}, _qzeros.options()).to(at::kChar);
for (int i = 0; i < _qweight.size(0); i++) {
auto res_ = convert_int4_weight_packed_with_compensation(
_qweight[i], _scales[i], _qzeros[i]);
blocked_weight[i] = std::get<0>(res_);
blocked_scales[i] = std::get<1>(res_);
blocked_qzeros[i] = std::get<2>(res_);
}
_qweight = blocked_weight;
_scales = blocked_scales;
_qzeros = blocked_qzeros;
} else {
auto res_ = convert_int4_weight_packed_with_compensation(_qweight, _scales,
_qzeros);
_qweight = std::get<0>(res_);
_scales = std::get<1>(res_);
_qzeros = std::get<2>(res_);
}
return std::make_tuple(_qweight, _qzeros, _scales);
}
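// DA8W4 GEMM with dynamic activation quantization: quantizes each input row
// to asymmetric uint8 with a fixed zero-point of 128, then runs the blocked
// int8 x int4 kernel and writes the result in output_dtype.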
at::Tensor int4_scaled_mm_cpu_with_quant(const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& weight_scales,
const at::Tensor& weight_qzeros,
const std::optional<at::Tensor>& bias,
at::ScalarType output_dtype) {
RECORD_FUNCTION("vllm::int4_scaled_mm_cpu_with_quant",
std::vector<c10::IValue>({input, weight}));
int64_t M_a = input.size(0);
int64_t K_a = input.size(1);
int64_t lda = input.stride(0);
const auto st = input.scalar_type();
TORCH_CHECK(
st == at::kBFloat16 || st == at::kHalf,
"int4_scaled_mm_cpu_with_quant: expect A to be bfloat16 or half.");
constexpr bool sym_quant_act = false;
using Tin = typename ActDtype<sym_quant_act>::type;
int64_t act_buffer_size =
M_a * K_a + M_a * sizeof(float) + M_a * sizeof(int32_t);
auto act_buffer =
at::empty({act_buffer_size}, input.options().dtype(at::kByte));
auto Aq_data = act_buffer.data_ptr<uint8_t>();
auto As_data = reinterpret_cast<float*>(Aq_data + M_a * K_a);
auto Azp_data = reinterpret_cast<int32_t*>(As_data + M_a);
fill_val_stub(Azp_data, 128, M_a);
auto out_sizes = input.sizes().vec();
int64_t N = weight_scales.size(0) * weight_scales.size(-1);
out_sizes.back() = N;
auto output = at::empty(out_sizes, input.options());
int64_t Nc = weight.size(0);
int64_t Kc = weight.size(1);
int64_t _block_k = K_a / Kc;
TORCH_CHECK(N == Nc * BLOCK_N, "DA8W4: weight and input shapes mismatch");
int64_t num_groups = weight_scales.size(1);
const uint8_t* b_ptr = weight.data_ptr<uint8_t>();
const float* b_scales_ptr = weight_scales.data_ptr<float>();
const int8_t* b_qzeros_ptr = weight_qzeros.data_ptr<int8_t>();
const float* bias_ptr =
bias.has_value() ? bias.value().data_ptr<float>() : nullptr;
int num_threads = at::get_num_threads();
int64_t temp_buffer_size = num_threads * BLOCK_M * BLOCK_N * sizeof(float) +
num_threads * _block_k * BLOCK_N;
auto c_temp_buffer =
at::empty({temp_buffer_size}, input.options().dtype(at::kChar));
float* c_temp_ptr = (float*)((void*)(c_temp_buffer.data_ptr<int8_t>()));
int8_t* dqB_temp_ptr =
(int8_t*)((void*)(c_temp_ptr + num_threads * BLOCK_M * BLOCK_N));
#define LAUNCH_DA8W4_LINEAR_WITH_QUANT_IMPL(sym_quant_act) \
AT_DISPATCH_FLOATING_TYPES_AND2( \
at::ScalarType::BFloat16, at::ScalarType::Half, output_dtype, \
"int4_scaled_mm_cpu", [&] { \
const scalar_t* __restrict__ A_data = input.data_ptr<scalar_t>(); \
scalar_t* __restrict__ c_ptr = output.data_ptr<scalar_t>(); \
at::parallel_for(0, M_a, 0, [&](int64_t begin, int64_t end) { \
for (int64_t m = begin; m < end; ++m) { \
quantize_row_int8<scalar_t>(Aq_data + m * K_a, As_data[m], \
A_data + m * lda, K_a); \
} \
}); \
_da8w4_linear_impl<sym_quant_act, Tin, scalar_t>( \
Aq_data, As_data, Azp_data, b_ptr, b_scales_ptr, b_qzeros_ptr, \
bias_ptr, c_ptr, c_temp_ptr, dqB_temp_ptr, M_a, N, K_a, \
num_groups); \
});
LAUNCH_DA8W4_LINEAR_WITH_QUANT_IMPL(sym_quant_act);
return output;
}
namespace {
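// Convert a row of fp32 accumulators to the (reduced-precision) output dtype
// and store it contiguously.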
template <typename scalar_t>
inline void copy_stub(scalar_t* __restrict__ out,
const float* __restrict__ input, int64_t size) {
using Vec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
#pragma GCC unroll 4
for (int64_t d = 0; d < size; d += Vec::size()) {
fVec x0 = fVec::loadu(input + d);
fVec x1 = fVec::loadu(input + d + fVec::size());
Vec res = convert_from_float_ext<scalar_t>(x0, x1);
res.store(out + d);
}
}
} // anonymous namespace
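// Single-tile entry point: accumulates one BLOCK_N-wide dequant-GEMM tile into
// the fp32 scratch buffer and, when store_out is set, converts the rows into
// the typed output C.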
template <typename scalar_t>
void tinygemm_kernel(scalar_t* C, float* C_temp, const uint8_t* A,
const float* scales_a, const int32_t* qzeros_a,
const uint8_t* B, const float* scales_b,
const int8_t* qzeros_b, const int32_t* compensation,
int8_t* dqB_tmp, int64_t M, int64_t K, int64_t lda,
int64_t ldc_f, int64_t ldc_s, bool store_out,
bool use_brgemm) {
_dequant_gemm_accum<false, BLOCK_N, BLOCK_N / 2>(
C_temp, A, scales_a, qzeros_a, B, scales_b, qzeros_b, compensation,
dqB_tmp, M, K, lda, ldc_f, use_brgemm);
if (store_out) {
for (int64_t m = 0; m < M; ++m) {
copy_stub<scalar_t>(C + m * ldc_s, C_temp + m * ldc_f, BLOCK_N);
}
}
}
#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \
template void tinygemm_kernel<TYPE>( \
TYPE * C, float* C_temp, const uint8_t* A, const float* scales_a, \
const int32_t* qzeros_a, const uint8_t* B, const float* scales_b, \
const int8_t* qzeros_b, const int32_t* compensation, int8_t* dqB_tmp, \
int64_t M, int64_t K, int64_t lda, int64_t ldc_f, int64_t ldc_s, \
bool store_out, bool use_brgemm)
INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16);
INSTANTIATE_TINYGEMM_TEMPLATE(at::Half);
at::Tensor int4_scaled_mm_cpu(at::Tensor& x, at::Tensor& w, at::Tensor& w_zeros,
at::Tensor& w_scales,
std::optional<at::Tensor> bias) {
return int4_scaled_mm_cpu_with_quant(x, w, w_scales, w_zeros, bias,
x.scalar_type());
}

View File

@@ -79,14 +79,6 @@ at::Tensor int8_scaled_mm_with_quant(at::Tensor& mat1, at::Tensor& mat2,
const std::optional<at::Tensor>& bias,
at::ScalarType out_dtype, bool is_vnni);
-// Adapted from sglang: INT4 W4A8 kernels
-std::tuple<at::Tensor, at::Tensor, at::Tensor> convert_weight_packed_scale_zp(
-at::Tensor qweight, at::Tensor qzeros, at::Tensor scales);
-at::Tensor int4_scaled_mm_cpu(at::Tensor& x, at::Tensor& w, at::Tensor& w_zeros,
-at::Tensor& w_scales,
-std::optional<at::Tensor> bias);
torch::Tensor get_scheduler_metadata(
const int64_t num_req, const int64_t num_heads_q,
const int64_t num_heads_kv, const int64_t head_dim,
@@ -293,18 +285,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor? bias, ScalarType out_dtype, bool is_vnni) -> Tensor");
ops.impl("int8_scaled_mm_with_quant", torch::kCPU,
&int8_scaled_mm_with_quant);
-// Adapted from sglang: INT4 W4A8 kernels
-ops.def(
-"convert_weight_packed_scale_zp(Tensor qweight, Tensor qzeros, "
-"Tensor scales) -> (Tensor, Tensor, Tensor)");
-ops.impl("convert_weight_packed_scale_zp", torch::kCPU,
-&convert_weight_packed_scale_zp);
-ops.def(
-"int4_scaled_mm_cpu(Tensor(a0!) x, Tensor(a1!) w, Tensor(a2!) w_zeros, "
-"Tensor(a3!) w_scales, Tensor? bias) -> Tensor");
-ops.impl("int4_scaled_mm_cpu", torch::kCPU, &int4_scaled_mm_cpu);
#endif
// CPU attention kernels

View File

@@ -3,8 +3,8 @@
#pragma once
-#include <torch/headeronly/util/BFloat16.h>
-#include <torch/headeronly/util/Half.h>
+#include <c10/util/BFloat16.h>
+#include <c10/util/Half.h>
#include <cassert>
#ifdef USE_ROCM

View File

@@ -6,15 +6,13 @@
#include <cstdio>
#include <cstdlib>
-#include <torch/headeronly/util/shim_utils.h>
/**
* Helper function for checking CUTLASS errors
*/
#define CUTLASS_CHECK(status) \
{ \
cutlass::Status error = status; \
-STD_TORCH_CHECK(error == cutlass::Status::kSuccess, \
+TORCH_CHECK(error == cutlass::Status::kSuccess, \
cutlassGetStatusString(error)); \
}

View File

@@ -1,6 +1,7 @@
#pragma once
#include <cute/tensor.hpp>
+#include <torch/all.h>
namespace cute {
////////////////////////////////////////////////////////////////////

View File

@@ -189,9 +189,9 @@ struct Sm90RowOrScalarBroadcastArray {
}
auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
-cute::Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
-cute::Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
-cute::Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
+Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
+Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
+Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
for (int i = 0; i < size(tGS_gRow_flt); ++i) {
if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
@@ -211,8 +211,8 @@ struct Sm90RowOrScalarBroadcastArray {
begin_loop(int epi_m, int epi_n) {
if (epi_m == 0) { // Assumes M-major subtile loop
if (!params.row_broadcast) return; // Do not issue LDS when row is scalar
-cute::Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
-cute::Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
+Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
+Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
copy(tSR_sRow_flt, tSR_rRow_flt);
}
}
@@ -241,9 +241,9 @@ struct Sm90RowOrScalarBroadcastArray {
auto [m, n, k, l] = args.tile_coord_mnkl;
using ThreadCount = decltype(size(args.tiled_copy));
-cute::Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row_array[l]), make_shape(M,N,1), params.dRow);
-cute::Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
-cute::Tensor sRow = make_tensor(make_smem_ptr(smem),
+Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row_array[l]), make_shape(M,N,1), params.dRow);
+Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
+Tensor sRow = make_tensor(make_smem_ptr(smem),
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N)
//// G2S: Gmem to Smem
auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
@@ -251,16 +251,16 @@ struct Sm90RowOrScalarBroadcastArray {
Stride<_0, _1>>{},
Layout<_1>{});
auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
-cute::Tensor tGS_gRow = thr_g2s.partition_S(gRow);
-cute::Tensor tGS_sRow = thr_g2s.partition_D(sRow);
+Tensor tGS_gRow = thr_g2s.partition_S(gRow);
+Tensor tGS_sRow = thr_g2s.partition_D(sRow);
//// G2S: Coord
auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
-cute::Tensor tGS_cRow = thr_g2s.partition_S(cRow);
+Tensor tGS_cRow = thr_g2s.partition_S(cRow);
//// S2R: Smem to Reg
-cute::Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
-cute::Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
+Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
+Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
return ConsumerStoreCallbacks<decltype(tGS_gRow), decltype(tGS_sRow), decltype(tGS_cRow), decltype(tiled_g2s), decltype(tSR_sRow), decltype(tSR_rRow), decltype(args.tCcD), decltype(args.residue_cD), ThreadCount>(
tGS_gRow,
@@ -389,7 +389,7 @@ struct Sm90ColOrScalarBroadcastArray {
CUTLASS_DEVICE void
begin() {
-cute::Tensor pred = make_tensor<bool>(shape(tCgCol));
+Tensor pred = make_tensor<bool>(shape(tCgCol));
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(pred); ++i) {
pred(i) = get<0>(tCcCol(i)) < m;
@@ -409,7 +409,7 @@ struct Sm90ColOrScalarBroadcastArray {
CUTLASS_DEVICE Array<Element, FragmentSize>
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
Array<Element, FragmentSize> frg_col;
-cute::Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
+Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < FragmentSize; ++i) {
@@ -431,16 +431,16 @@ struct Sm90ColOrScalarBroadcastArray {
auto [M, N, K, L] = args.problem_shape_mnkl;
auto [m, n, k, l] = args.tile_coord_mnkl;
-cute::Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col_array[l]), make_shape(M,N,1), params.dCol);
-cute::Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
+Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col_array[l]), make_shape(M,N,1), params.dCol);
+Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
-cute::Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
+Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
// Generate an identity tensor matching the shape of the global tensor and
// partition the same way, this will be used to generate the predicate
// tensor for loading
-cute::Tensor cCol = make_identity_tensor(mCol.shape());
-cute::Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
+Tensor cCol = make_identity_tensor(mCol.shape());
+Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
return ConsumerStoreCallbacks(

View File

@@ -186,9 +186,9 @@ struct Sm90RowOrScalarBroadcast {
}
auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
-cute::Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
-cute::Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
-cute::Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
+Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
+Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
+Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
for (int i = 0; i < size(tGS_gRow_flt); ++i) {
if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
@@ -208,8 +208,8 @@ struct Sm90RowOrScalarBroadcast {
begin_loop(int epi_m, int epi_n) {
if (epi_m == 0) { // Assumes M-major subtile loop
if (!params.row_broadcast) return; // Do not issue LDS when row is scalar
-cute::Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
-cute::Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
+Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
+Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
copy(tSR_sRow_flt, tSR_rRow_flt);
}
}
@@ -238,9 +238,9 @@ struct Sm90RowOrScalarBroadcast {
auto [m, n, k, l] = args.tile_coord_mnkl;
using ThreadCount = decltype(size(args.tiled_copy));
-cute::Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
-cute::Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
-cute::Tensor sRow = make_tensor(make_smem_ptr(smem),
+Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
+Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
+Tensor sRow = make_tensor(make_smem_ptr(smem),
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N)
//// G2S: Gmem to Smem
auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
@@ -248,16 +248,16 @@ struct Sm90RowOrScalarBroadcast {
Stride<_0, _1>>{},
Layout<_1>{});
auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
-cute::Tensor tGS_gRow = thr_g2s.partition_S(gRow);
-cute::Tensor tGS_sRow = thr_g2s.partition_D(sRow);
+Tensor tGS_gRow = thr_g2s.partition_S(gRow);
+Tensor tGS_sRow = thr_g2s.partition_D(sRow);
//// G2S: Coord
auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
-cute::Tensor tGS_cRow = thr_g2s.partition_S(cRow);
+Tensor tGS_cRow = thr_g2s.partition_S(cRow);
//// S2R: Smem to Reg
-cute::Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
-cute::Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
+Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
+Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
return ConsumerStoreCallbacks<decltype(tGS_gRow), decltype(tGS_sRow), decltype(tGS_cRow), decltype(tiled_g2s), decltype(tSR_sRow), decltype(tSR_rRow), decltype(args.tCcD), decltype(args.residue_cD), ThreadCount>(
tGS_gRow,
@@ -382,7 +382,7 @@ struct Sm90ColOrScalarBroadcast {
CUTLASS_DEVICE void
begin() {
-cute::Tensor pred = make_tensor<bool>(shape(tCgCol));
+Tensor pred = make_tensor<bool>(shape(tCgCol));
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(pred); ++i) {
pred(i) = get<0>(tCcCol(i)) < m;
@@ -402,7 +402,7 @@ struct Sm90ColOrScalarBroadcast {
CUTLASS_DEVICE Array<Element, FragmentSize>
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
Array<Element, FragmentSize> frg_col;
-cute::Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
+Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < FragmentSize; ++i) {
@@ -422,16 +422,16 @@ struct Sm90ColOrScalarBroadcast {
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
auto [M, N, K, L] = args.problem_shape_mnkl;
-cute::Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol);
-cute::Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
+Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol);
+Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
-cute::Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
+Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
// Generate an identity tensor matching the shape of the global tensor and
// partition the same way, this will be used to generate the predicate
// tensor for loading
-cute::Tensor cCol = make_identity_tensor(mCol.shape());
-cute::Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
+Tensor cCol = make_identity_tensor(mCol.shape());
+Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
return ConsumerStoreCallbacks(

View File

@@ -1,7 +1,5 @@
#pragma once
-#include <torch/csrc/stable/tensor.h>
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp"
/*
@@ -54,7 +52,7 @@ struct ScaledEpilogueBase {
// from a tensor. It can handle both row and column, as well as row/column or
// scalar cases.
template <typename Descriptor, typename T>
-static auto args_from_tensor(torch::stable::Tensor const& tensor) {
+static auto args_from_tensor(torch::Tensor const& tensor) {
using Arguments = typename Descriptor::Arguments;
auto* data_ptr = static_cast<T*>(tensor.data_ptr());
if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
@@ -70,8 +68,7 @@ struct ScaledEpilogueBase {
// This overload handles the case where there might not be a tensor, in which
// case a nullptr is passed and a constant (0) is used.
template <typename Descriptor, typename T>
-static auto args_from_tensor(
-std::optional<torch::stable::Tensor> const& tensor) {
+static auto args_from_tensor(std::optional<torch::Tensor> const& tensor) {
static_assert(std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
using Arguments = typename Descriptor::Arguments;
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
@@ -120,8 +117,8 @@ struct ScaledEpilogue
cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
using ArgumentType = typename EVTCompute::Arguments;
-static ArgumentType prepare_args(torch::stable::Tensor const& a_scales,
-torch::stable::Tensor const& b_scales) {
+static ArgumentType prepare_args(torch::Tensor const& a_scales,
+torch::Tensor const& b_scales) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
@@ -163,9 +160,9 @@ struct ScaledEpilogueBias
using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
EVTCompute0, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
-static ArgumentType prepare_args(torch::stable::Tensor const& a_scales,
-torch::stable::Tensor const& b_scales,
-torch::stable::Tensor const& bias) {
+static ArgumentType prepare_args(torch::Tensor const& a_scales,
+torch::Tensor const& b_scales,
+torch::Tensor const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
@@ -223,11 +220,10 @@ struct ScaledEpilogueBiasAzp
using ArgumentType = typename EVTCompute::Arguments;
-static ArgumentType prepare_args(
-torch::stable::Tensor const& a_scales,
-torch::stable::Tensor const& b_scales,
-torch::stable::Tensor const& azp_adj,
-std::optional<torch::stable::Tensor> const& bias) {
+static ArgumentType prepare_args(torch::Tensor const& a_scales,
+torch::Tensor const& b_scales,
+torch::Tensor const& azp_adj,
+std::optional<torch::Tensor> const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
@@ -302,11 +298,11 @@ struct ScaledEpilogueBiasAzpToken
using ArgumentType = typename EVTCompute::Arguments;
-static ArgumentType prepare_args(
-torch::stable::Tensor const& a_scales,
-torch::stable::Tensor const& b_scales,
-torch::stable::Tensor const& azp_adj, torch::stable::Tensor const& azp,
-std::optional<torch::stable::Tensor> const& bias) {
+static ArgumentType prepare_args(torch::Tensor const& a_scales,
+torch::Tensor const& b_scales,
+torch::Tensor const& azp_adj,
+torch::Tensor const& azp,
+std::optional<torch::Tensor> const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);

View File

@@ -3,14 +3,6 @@
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp" #include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp" #include "cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp"
// This header is shared by both _C (unstable ABI) and _C_stable_libtorch
// (stable ABI) targets. When compiled under the stable ABI target,
// TORCH_TARGET_VERSION is defined and Tensor is unavailable, so we
// use torch::stable::Tensor instead.
#ifdef TORCH_TARGET_VERSION
#include <torch/csrc/stable/tensor.h>
#endif
/* /*
This file defines custom epilogues for fusing channel scales, token scales, This file defines custom epilogues for fusing channel scales, token scales,
bias, and activation zero-points onto a GEMM operation using the bias, and activation zero-points onto a GEMM operation using the
@@ -23,12 +15,6 @@
namespace vllm::c3x { namespace vllm::c3x {
#ifdef TORCH_TARGET_VERSION
using TensorType = torch::stable::Tensor;
#else
using TensorType = torch::Tensor;
#endif
using namespace cute; using namespace cute;
template <typename T> template <typename T>
@@ -98,7 +84,7 @@ struct ScaledEpilogueBase {
// from a tensor. It can handle both row and column, as well as row/column or // from a tensor. It can handle both row and column, as well as row/column or
// scalar cases. // scalar cases.
template <typename Descriptor, typename T> template <typename Descriptor, typename T>
static auto args_from_tensor(TensorType const& tensor) { static auto args_from_tensor(torch::Tensor const& tensor) {
using Arguments = typename Descriptor::Arguments; using Arguments = typename Descriptor::Arguments;
auto* data_ptr = static_cast<T*>(tensor.data_ptr()); auto* data_ptr = static_cast<T*>(tensor.data_ptr());
if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> || if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
@@ -114,7 +100,7 @@ struct ScaledEpilogueBase {
// This overload handles the case where there might not be a tensor, in which // This overload handles the case where there might not be a tensor, in which
// case a nullptr is passed and a constant (0) is used. // case a nullptr is passed and a constant (0) is used.
template <typename Descriptor, typename T> template <typename Descriptor, typename T>
static auto args_from_tensor(std::optional<TensorType> const& tensor) { static auto args_from_tensor(std::optional<torch::Tensor> const& tensor) {
using Arguments = typename Descriptor::Arguments; using Arguments = typename Descriptor::Arguments;
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr; auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
static_assert(std::is_same_v<Descriptor, ColLoad<T, true>> || static_assert(std::is_same_v<Descriptor, ColLoad<T, true>> ||
@@ -172,8 +158,8 @@ struct ScaledEpilogue
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>; cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
using ArgumentType = typename EVTCompute::Arguments; using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(TensorType const& a_scales, static ArgumentType prepare_args(torch::Tensor const& a_scales,
TensorType const& b_scales) { torch::Tensor const& b_scales) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales); auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales); auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
@@ -217,9 +203,9 @@ struct ScaledEpilogueBias
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>; cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
using ArgumentType = typename EVTCompute::Arguments; using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(TensorType const& a_scales, static ArgumentType prepare_args(torch::Tensor const& a_scales,
TensorType const& b_scales, torch::Tensor const& b_scales,
TensorType const& bias) { torch::Tensor const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales); auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales); auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias); auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
@@ -260,9 +246,9 @@ struct ScaledEpilogueColumnBias
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>; cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
using ArgumentType = typename EVTCompute::Arguments; using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(TensorType const& a_scales, static ArgumentType prepare_args(torch::Tensor const& a_scales,
TensorType const& b_scales, torch::Tensor const& b_scales,
TensorType const& bias) { torch::Tensor const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales); auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales); auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias); auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
@@ -318,10 +304,10 @@ struct ScaledEpilogueBiasAzp
EVTComputeScaleB, Bias>; EVTComputeScaleB, Bias>;
using ArgumentType = typename EVTCompute::Arguments; using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(TensorType const& a_scales, static ArgumentType prepare_args(torch::Tensor const& a_scales,
TensorType const& b_scales, torch::Tensor const& b_scales,
TensorType const& azp_adj, torch::Tensor const& azp_adj,
std::optional<TensorType> const& bias) { std::optional<torch::Tensor> const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales); auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales); auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias); auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
@@ -394,11 +380,11 @@ struct ScaledEpilogueBiasAzpToken
EVTComputeScaleB, Bias>; EVTComputeScaleB, Bias>;
using ArgumentType = typename EVTCompute::Arguments; using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(TensorType const& a_scales, static ArgumentType prepare_args(torch::Tensor const& a_scales,
TensorType const& b_scales, torch::Tensor const& b_scales,
TensorType const& azp_adj, torch::Tensor const& azp_adj,
TensorType const& azp, torch::Tensor const& azp,
std::optional<TensorType> const& bias) { std::optional<torch::Tensor> const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales); auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales); auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias); auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);

View File

@@ -1,21 +1,6 @@
#pragma once
-// This header is shared between _C (unstable ABI, used by machete) and
-// _C_stable_libtorch (stable ABI, used by W4A8/sparse). TORCH_TARGET_VERSION
-// is defined only for the stable target, so we switch includes and types
-// accordingly. TorchTensor (not Tensor) avoids ambiguity with cute::Tensor.
-#ifdef TORCH_TARGET_VERSION
-#include <torch/csrc/stable/tensor.h>
-#include <torch/headeronly/util/BFloat16.h>
-#include <torch/headeronly/util/Half.h>
-#include <torch/headeronly/util/shim_utils.h>  // for STD_TORCH_CHECK
-using TorchTensor = torch::stable::Tensor;
-#define TORCH_UTILS_CHECK STD_TORCH_CHECK
-#else
-#include <torch/all.h>
-using TorchTensor = torch::Tensor;
-#define TORCH_UTILS_CHECK TORCH_CHECK
-#endif
+#include <torch/all.h>
#include "cute/layout.hpp"
#include "cutlass/layout/matrix.h"
@@ -70,16 +55,16 @@ CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f) {
// If `tensor.dim() < rank(Stride{})`, the shape is padded with 1s and the extra
// strides are set to be 0 or 1.
template <typename Stride>
-static inline auto make_cute_layout(TorchTensor const& tensor,
+static inline auto make_cute_layout(torch::Tensor const& tensor,
std::string_view name = "tensor") {
-TORCH_UTILS_CHECK(tensor.dim() <= rank(Stride{}));
-auto stride = cute::transform_with_idx(Stride{}, [&](auto const& stride_ele,
-auto const& idx) {
+TORCH_CHECK(tensor.dim() <= rank(Stride{}));
+auto stride = cute::transform_with_idx(
+Stride{}, [&](auto const& stride_ele, auto const& idx) {
using StrideEle = std::decay_t<decltype(stride_ele)>;
if (idx < tensor.dim()) {
if constexpr (cute::is_static_v<StrideEle>) {
-TORCH_UTILS_CHECK(StrideEle::value == tensor.stride(idx), "Expected ",
+TORCH_CHECK(StrideEle::value == tensor.stride(idx), "Expected ",
name, ".stride(", idx, ") to be ", StrideEle::value);
return StrideEle{};
} else {
@@ -112,7 +97,7 @@ static inline auto make_cute_layout(TorchTensor const& tensor,
template <typename Stride>
static inline auto maybe_make_cute_layout(
-std::optional<TorchTensor> const& tensor,
+std::optional<torch::Tensor> const& tensor,
std::string_view name = "tensor") {
using Layout = decltype(make_cute_layout<Stride>(*tensor));
@@ -136,12 +121,12 @@ template <typename T>
using equivalent_cutlass_type_t = typename equivalent_cutlass_type<T>::type;
template <>
-struct equivalent_cutlass_type<torch::headeronly::Half> {
+struct equivalent_cutlass_type<c10::Half> {
using type = cutlass::half_t;
};
template <>
-struct equivalent_cutlass_type<torch::headeronly::BFloat16> {
+struct equivalent_cutlass_type<c10::BFloat16> {
using type = cutlass::bfloat16_t;
};
@@ -149,8 +134,8 @@ struct equivalent_cutlass_type<torch::headeronly::BFloat16> {
// equivalent_scalar_t (basically inverse of equivalent_cutlass_type)
//
-// Return a `torch::headeronly::CppTypeToScalarType<T>` compatible type, i.e.
-// get the C++ type equivalent to T, e.g.: `cutlass::half_t -> Half`
+// Return a `c10::CppTypeToScalarType<T>` compatible type, i.e. get the C++ from
+// c10 that is equivalent to T, e.g.: `cutlass::half_t -> c10::Half`
template <typename T>
struct equivalent_scalar_type {
using type = T;
@@ -161,15 +146,15 @@ using equivalent_scalar_type_t = typename equivalent_scalar_type<T>::type;
template <>
struct equivalent_scalar_type<cutlass::half_t> {
-using type = torch::headeronly::Half;
+using type = c10::Half;
};
template <>
struct equivalent_scalar_type<cutlass::bfloat16_t> {
-using type = torch::headeronly::BFloat16;
+using type = c10::BFloat16;
};
-// get equivalent torch::headeronly::ScalarType tag from compile time type
+// get equivalent c10::ScalarType tag from compile time type
template <typename T>
-static inline constexpr torch::headeronly::ScalarType equivalent_scalar_type_v =
-torch::headeronly::CppTypeToScalarType<equivalent_scalar_type_t<T>>::value;
+static inline constexpr c10::ScalarType equivalent_scalar_type_v =
+c10::CppTypeToScalarType<equivalent_scalar_type_t<T>>::value;

View File

@@ -49,15 +49,6 @@
THO_DISPATCH_SWITCH(TYPE, NAME, \
VLLM_STABLE_DISPATCH_CASE_FP8_TYPES(__VA_ARGS__))
-// Half types dispatch (Half + BFloat16)
-#define VLLM_STABLE_DISPATCH_CASE_HALF_TYPES(...) \
-THO_DISPATCH_CASE(torch::headeronly::ScalarType::Half, __VA_ARGS__) \
-THO_DISPATCH_CASE(torch::headeronly::ScalarType::BFloat16, __VA_ARGS__)
-#define VLLM_STABLE_DISPATCH_HALF_TYPES(TYPE, NAME, ...) \
-THO_DISPATCH_SWITCH(TYPE, NAME, \
-VLLM_STABLE_DISPATCH_CASE_HALF_TYPES(__VA_ARGS__))
// Boolean dispatch
#define VLLM_STABLE_DISPATCH_BOOL(expr, const_expr, ...) \
if (expr) { \

View File

@@ -27,111 +27,4 @@ void per_token_group_quant_int8(const torch::stable::Tensor& input,
torch::stable::Tensor& output_s,
int64_t group_size, double eps, double int8_min,
double int8_max);
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability);
bool cutlass_group_gemm_supported(int64_t cuda_device_capability);
void cutlass_scaled_mm(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_moe_mm(torch::stable::Tensor& out_tensors,
torch::stable::Tensor const& a_tensors,
torch::stable::Tensor const& b_tensors,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& expert_offsets,
torch::stable::Tensor const& problem_sizes,
torch::stable::Tensor const& a_strides,
torch::stable::Tensor const& b_strides,
torch::stable::Tensor const& c_strides, bool per_act_token,
bool per_out_ch);
void cutlass_scaled_mm_azp(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias);
void get_cutlass_moe_mm_data(
const torch::stable::Tensor& topk_ids,
torch::stable::Tensor& expert_offsets,
torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2,
torch::stable::Tensor& input_permutation,
torch::stable::Tensor& output_permutation, const int64_t num_experts,
const int64_t n, const int64_t k,
const std::optional<torch::stable::Tensor>& blockscale_offsets,
const bool is_gated);
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
const torch::stable::Tensor& expert_first_token_offset,
torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2, const int64_t n, const int64_t k,
const bool swap_ab);
void get_cutlass_batched_moe_mm_data(
torch::stable::Tensor& expert_offsets,
torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2,
const torch::stable::Tensor& expert_num_tokens,
const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
const int64_t k);
// FP4/NVFP4 ops
bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability);
void cutlass_scaled_fp4_mm(torch::stable::Tensor& D,
torch::stable::Tensor const& A,
torch::stable::Tensor const& B,
torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha);
void cutlass_fp4_group_mm(torch::stable::Tensor& output,
const torch::stable::Tensor& a,
const torch::stable::Tensor& b,
const torch::stable::Tensor& a_blockscale,
const torch::stable::Tensor& b_blockscales,
const torch::stable::Tensor& alphas,
const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& sf_offsets);
std::tuple<torch::stable::Tensor, torch::stable::Tensor> scaled_fp4_quant_func(
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_scale, bool is_sf_swizzled_layout);
void scaled_fp4_quant_out(torch::stable::Tensor const& input,
torch::stable::Tensor const& input_scale,
bool is_sf_swizzled_layout,
torch::stable::Tensor& output,
torch::stable::Tensor& output_scale);
void scaled_fp4_experts_quant(
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_global_scale,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts);
void silu_and_mul_scaled_fp4_experts_quant(
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_global_scale,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts);
void silu_and_mul_nvfp4_quant(torch::stable::Tensor& out,
torch::stable::Tensor& output_block_scale,
torch::stable::Tensor& input,
torch::stable::Tensor& input_global_scale);
#endif

View File

@@ -1,175 +0,0 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "cutlass_extensions/common.hpp"
#include "nvfp4_utils.cuh"
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
void scaled_fp4_quant_sm1xxa(torch::stable::Tensor const& output,
torch::stable::Tensor const& input,
torch::stable::Tensor const& output_sf,
torch::stable::Tensor const& input_sf,
bool is_sf_swizzled_layout);
#endif
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
void scaled_fp4_experts_quant_sm1xxa(
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_global_scale,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts);
#endif
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
void silu_and_mul_nvfp4_quant_sm1xxa(torch::stable::Tensor& output,
torch::stable::Tensor& output_sf,
torch::stable::Tensor& input,
torch::stable::Tensor& input_sf);
#endif
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_global_scale,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts);
#endif
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
static bool nvfp4_quant_sm_supported() {
const int32_t sm = get_sm_version_num();
#if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100
if (sm >= 100 && sm < 120) return true;
#endif
#if defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120
if (sm >= 120 && sm < 130) return true;
#endif
return false;
}
#endif
void scaled_fp4_quant_out(torch::stable::Tensor const& input,
torch::stable::Tensor const& input_sf,
bool is_sf_swizzled_layout,
torch::stable::Tensor& output,
torch::stable::Tensor& output_sf) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
STD_TORCH_CHECK(nvfp4_quant_sm_supported(),
"No compiled nvfp4 quantization kernel for SM ",
get_sm_version_num(),
". Recompile with the appropriate CUDA arch.");
return scaled_fp4_quant_sm1xxa(output, input, output_sf, input_sf,
is_sf_swizzled_layout);
#endif
STD_TORCH_CHECK_NOT_IMPLEMENTED(false,
"No compiled nvfp4 quantization kernel");
}
std::tuple<torch::stable::Tensor, torch::stable::Tensor> scaled_fp4_quant_func(
torch::stable::Tensor const& input, torch::stable::Tensor const& input_sf,
bool is_sf_swizzled_layout) {
int64_t n = input.size(-1);
int64_t m = input.numel() / n;
auto device = input.device();
// Two fp4 values packed into a uint8
auto output = torch::stable::empty(
{m, n / 2}, torch::headeronly::ScalarType::Byte, std::nullopt, device);
torch::stable::Tensor output_sf;
if (is_sf_swizzled_layout) {
auto [sf_m, sf_n] = vllm::computeSwizzledSFShape(m, n);
output_sf = torch::stable::empty(
{sf_m, sf_n}, torch::headeronly::ScalarType::Int, std::nullopt, device);
} else {
output_sf = torch::stable::empty({m, n / CVT_FP4_SF_VEC_SIZE},
torch::headeronly::ScalarType::Byte,
std::nullopt, device);
}
scaled_fp4_quant_out(input, input_sf, is_sf_swizzled_layout, output,
output_sf);
return {output, output_sf};
}
void scaled_fp4_experts_quant(
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_global_scale,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
STD_TORCH_CHECK(nvfp4_quant_sm_supported(),
"No compiled nvfp4 experts quantization kernel for SM ",
get_sm_version_num(),
". Recompile with the appropriate CUDA arch.");
return scaled_fp4_experts_quant_sm1xxa(
output, output_scale, input, input_global_scale, input_offset_by_experts,
output_scale_offset_by_experts);
#endif
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false, "No compiled nvfp4 experts quantization kernel");
}
void silu_and_mul_nvfp4_quant(torch::stable::Tensor& output,
torch::stable::Tensor& output_sf,
torch::stable::Tensor& input,
torch::stable::Tensor& input_sf) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
STD_TORCH_CHECK(nvfp4_quant_sm_supported(),
"No compiled silu_and_mul nvfp4 quantization kernel for SM ",
get_sm_version_num(),
". Recompile with the appropriate CUDA arch.");
return silu_and_mul_nvfp4_quant_sm1xxa(output, output_sf, input, input_sf);
#endif
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false, "No compiled silu_and_mul nvfp4 quantization kernel");
}
void silu_and_mul_scaled_fp4_experts_quant(
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_global_scale,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
STD_TORCH_CHECK(nvfp4_quant_sm_supported(),
"No compiled silu_and_mul nvfp4 experts quantization kernel "
"for SM ",
get_sm_version_num(),
". Recompile with the appropriate CUDA arch.");
return silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
output, output_scale, input, input_global_scale, input_offset_by_experts,
output_scale_offset_by_experts);
#endif
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false, "No compiled silu_and_mul nvfp4 experts quantization kernel");
}
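For reference, the allocation logic in scaled_fp4_quant_func above fixes the output shapes; a short worked example follows, assuming CVT_FP4_SF_VEC_SIZE is 16 (the constant itself is defined in nvfp4_utils.cuh and is not shown in this diff).
// Illustrative shape arithmetic only; CVT_FP4_SF_VEC_SIZE == 16 is an assumption.
// input:  [m, n] = [4, 128], any supported float dtype (n must be even)
// output: [m, n / 2]               = [4, 64] uint8  (two FP4 values packed per byte)
// output_sf, non-swizzled layout:  [m, n / CVT_FP4_SF_VEC_SIZE] = [4, 8] uint8
// output_sf, swizzled layout:      the padded {sf_m, sf_n} shape returned by
//                                  vllm::computeSwizzledSFShape(4, 128), allocated as int32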

View File

@@ -1,87 +0,0 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "cutlass_extensions/common.hpp"
#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
void cutlass_scaled_fp4_mm_sm100a(torch::stable::Tensor& D,
torch::stable::Tensor const& A,
torch::stable::Tensor const& B,
torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha);
#endif
#if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
void cutlass_scaled_fp4_mm_sm120a(torch::stable::Tensor& D,
torch::stable::Tensor const& A,
torch::stable::Tensor const& B,
torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha);
#endif
void cutlass_scaled_fp4_mm(torch::stable::Tensor& D,
const torch::stable::Tensor& A,
const torch::stable::Tensor& B,
const torch::stable::Tensor& A_sf,
const torch::stable::Tensor& B_sf,
const torch::stable::Tensor& alpha) {
// Make sure we're on A's device.
const torch::stable::accelerator::DeviceGuard device_guard(
A.get_device_index());
const int32_t sm = get_sm_version_num();
#if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100
if (sm >= 100 && sm < 120) {
cutlass_scaled_fp4_mm_sm100a(D, A, B, A_sf, B_sf, alpha);
return;
}
#endif
#if defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120
if (sm >= 120 && sm < 130) {
cutlass_scaled_fp4_mm_sm120a(D, A, B, A_sf, B_sf, alpha);
return;
}
#endif
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false, "No compiled nvfp4 mm kernel for SM ", sm,
". Recompile with CUDA >= 12.8 and CC >= 100.");
}
bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability) {
int runtimeVersion;
cudaRuntimeGetVersion(&runtimeVersion);
if (runtimeVersion < 12080) return false;
// Only report support when the SM-specific kernel was actually compiled in,
// so the Python-side backend selector does not choose CUTLASS and then hit
// TORCH_CHECK_NOT_IMPLEMENTED (or worse, fall through to Marlin).
#if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100
if (cuda_device_capability >= 100 && cuda_device_capability < 120)
return true;
#endif
#if defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120
if (cuda_device_capability >= 120 && cuda_device_capability < 130)
return true;
#endif
return false;
}
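A rough sketch of how a host-side gate would typically consume this capability check before dispatching to the CUTLASS FP4 GEMM; fallback_fp4_gemm is a hypothetical placeholder, not a function from this diff.
// Illustrative only; D, A, B, A_sf, B_sf, alpha are prepared elsewhere.
int64_t cc = get_sm_version_num();                  // e.g. 100 on B200, 90 on H100
if (cutlass_scaled_mm_supports_fp4(cc)) {
  // CUTLASS path: only taken when the matching SM100/SM120 kernel was compiled in
  // and the CUDA runtime is at least 12.8.
  cutlass_scaled_fp4_mm(D, A, B, A_sf, B_sf, alpha);
} else {
  fallback_fp4_gemm(D, A, B, A_sf, B_sf, alpha);    // hypothetical non-CUTLASS fallback
}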

View File

@@ -1,22 +0,0 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_blockwise_sm100_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace vllm {
void cutlass_scaled_mm_blockwise_sm100_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales) {
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::bfloat16_t>(
out, a, b, a_scales, b_scales);
} else {
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::half_t>(
out, a, b, a_scales, b_scales);
}
}
} // namespace vllm

View File

@@ -1,22 +0,0 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_blockwise_sm120_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace vllm {
void cutlass_scaled_mm_blockwise_sm120_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales) {
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
cutlass_gemm_blockwise_sm120_fp8_dispatch<cutlass::bfloat16_t>(
out, a, b, a_scales, b_scales);
} else {
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
cutlass_gemm_blockwise_sm120_fp8_dispatch<cutlass::half_t>(
out, a, b, a_scales, b_scales);
}
}
} // namespace vllm

View File

@@ -1,23 +0,0 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_blockwise_sm90_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace vllm {
void cutlass_scaled_mm_blockwise_sm90_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales) {
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::bfloat16_t>(
out, a, b, a_scales, b_scales);
} else {
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::half_t>(
out, a, b, a_scales, b_scales);
}
}
} // namespace vllm
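The three blockwise FP8 entry points above differ only in the SM-specific dispatch template; a hedged sketch of how the repeated BF16/FP16 branch could be factored. The Impl wrapper is hypothetical (a struct with a static run method), since the cutlass_gemm_blockwise_smXX_fp8_dispatch function templates cannot be passed as template template parameters directly.
// Sketch only, not part of the diff.
template <template <typename> class Impl>
void dispatch_blockwise_fp8_out_dtype(torch::stable::Tensor& out,
                                      torch::stable::Tensor const& a,
                                      torch::stable::Tensor const& b,
                                      torch::stable::Tensor const& a_scales,
                                      torch::stable::Tensor const& b_scales) {
  if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
    Impl<cutlass::bfloat16_t>::run(out, a, b, a_scales, b_scales);
  } else {
    STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
    Impl<cutlass::half_t>::run(out, a, b, a_scales, b_scales);
  }
}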

View File

@@ -1,52 +0,0 @@
#pragma once
#include <torch/csrc/stable/tensor.h>
namespace vllm {
void cutlass_scaled_mm_sm90_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_scaled_mm_sm90_int8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_scaled_mm_azp_sm90_int8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_scaled_mm_blockwise_sm90_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales);
void cutlass_scaled_mm_sm100_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_scaled_mm_sm120_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_scaled_mm_blockwise_sm100_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales);
void cutlass_scaled_mm_blockwise_sm120_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales);
} // namespace vllm

View File

@@ -1,24 +0,0 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm100_fp8_dispatch.cuh"
namespace vllm {
void cutlass_scaled_mm_sm100_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias) {
STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (bias) {
STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
"currently bias dtype must match output dtype ",
out.scalar_type());
return cutlass_scaled_mm_sm100_fp8_epilogue<true>(out, a, b, a_scales,
b_scales, *bias);
} else {
return cutlass_scaled_mm_sm100_fp8_epilogue<false>(out, a, b, a_scales,
b_scales);
}
}
} // namespace vllm

View File

@@ -1,25 +0,0 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm120_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace vllm {
void cutlass_scaled_mm_sm120_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias) {
STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (bias) {
STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
"currently bias dtype must match output dtype ",
out.scalar_type());
return cutlass_scaled_mm_sm120_fp8_epilogue<c3x::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias);
} else {
return cutlass_scaled_mm_sm120_fp8_epilogue<c3x::ScaledEpilogue>(
out, a, b, a_scales, b_scales);
}
}
} // namespace vllm

View File

@@ -1,24 +0,0 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm90_fp8_dispatch.cuh"
namespace vllm {
void cutlass_scaled_mm_sm90_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias) {
STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (bias) {
STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
"currently bias dtype must match output dtype ",
out.scalar_type());
return cutlass_scaled_mm_sm90_fp8_epilogue<true>(out, a, b, a_scales,
b_scales, *bias);
} else {
return cutlass_scaled_mm_sm90_fp8_epilogue<false>(out, a, b, a_scales,
b_scales);
}
}
} // namespace vllm

View File

@@ -1,25 +0,0 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm90_int8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace vllm {
void cutlass_scaled_mm_sm90_int8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias) {
STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (bias) {
STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
"currently bias dtype must match output dtype ",
out.scalar_type());
return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias);
} else {
return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogue>(
out, a, b, a_scales, b_scales);
}
}
} // namespace vllm

View File

@@ -1,220 +0,0 @@
#include <stddef.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/core/ScalarType.h>
#include "cutlass/cutlass.h"
#include "scaled_mm_c2x.cuh"
#include "scaled_mm_c2x_sm75_dispatch.cuh"
#include "scaled_mm_c2x_sm80_dispatch.cuh"
#include "scaled_mm_c2x_sm89_fp8_dispatch.cuh"
#include "scaled_mm_c2x_sm89_int8_dispatch.cuh"
#include "libtorch_stable/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp"
using namespace vllm;
/*
This file defines quantized GEMM operations using the CUTLASS 2.x API, for
NVIDIA GPUs with SM versions prior to sm90 (Hopper).
*/
template <template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_scaled_mm_sm75_epilogue(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
EpilogueArgs&&... epilogue_args) {
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
return cutlass_gemm_sm75_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
return cutlass_gemm_sm75_dispatch<int8_t, cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
}
void cutlass_scaled_mm_sm75(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias) {
STD_TORCH_CHECK(a_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
if (bias) {
STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
"currently bias dtype must match output dtype ",
out.scalar_type());
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias);
} else {
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogue>(
out, a, b, a_scales, b_scales);
}
}
void cutlass_scaled_mm_azp_sm75(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias) {
STD_TORCH_CHECK(a_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
if (azp) {
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
} else {
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBiasAzp>(
out, a, b, a_scales, b_scales, azp_adj, bias);
}
}
template <template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_scaled_mm_sm80_epilogue(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
EpilogueArgs&&... epilogue_args) {
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
return cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
return cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
}
void cutlass_scaled_mm_sm80(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias) {
STD_TORCH_CHECK(a_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
if (bias) {
STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
"currently bias dtype must match output dtype ",
out.scalar_type());
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias);
} else {
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogue>(
out, a, b, a_scales, b_scales);
}
}
void cutlass_scaled_mm_azp_sm80(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias) {
STD_TORCH_CHECK(a_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
if (azp) {
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
} else {
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBiasAzp>(
out, a, b, a_scales, b_scales, azp_adj, bias);
}
}
template <template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_scaled_mm_sm89_epilogue(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
EpilogueArgs&&... epilogue_args) {
if (a.scalar_type() == torch::headeronly::ScalarType::Char) {
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::bfloat16_t,
Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
assert(out.scalar_type() == torch::headeronly::ScalarType::Half);
return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
} else {
STD_TORCH_CHECK(a.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
STD_TORCH_CHECK(b.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::bfloat16_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
}
}
void cutlass_scaled_mm_sm89(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias) {
STD_TORCH_CHECK(a_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
if (bias) {
STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
"currently bias dtype must match output dtype ",
out.scalar_type());
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias);
} else {
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogue>(
out, a, b, a_scales, b_scales);
}
}
void cutlass_scaled_mm_azp_sm89(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias) {
STD_TORCH_CHECK(a_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
if (azp) {
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
} else {
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBiasAzp>(
out, a, b, a_scales, b_scales, azp_adj, bias);
}
}
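To summarize the epilogue selection shared by the sm75/sm80/sm89 entry points above; the per-tensor vs per-token distinction follows from the numel checks applied to azp by the callers.
// bias absent,  no azp path           -> c2x::ScaledEpilogue
// bias present, no azp path           -> c2x::ScaledEpilogueBias
// azp path, azp == std::nullopt       -> c2x::ScaledEpilogueBiasAzp      (per-tensor zero point via azp_adj)
// azp path, azp tensor of m elements  -> c2x::ScaledEpilogueBiasAzpToken (per-token zero points)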

View File

@@ -1,38 +0,0 @@
#include "c3x/scaled_mm_helper.hpp"
#include "c3x/scaled_mm_kernels.hpp"
/*
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
NVIDIA GPUs with sm90a (Hopper).
*/
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
void cutlass_scaled_mm_sm90(torch::stable::Tensor& c,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias) {
dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
vllm::cutlass_scaled_mm_sm90_fp8,
vllm::cutlass_scaled_mm_sm90_int8,
vllm::cutlass_scaled_mm_blockwise_sm90_fp8);
}
void cutlass_scaled_mm_azp_sm90(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias) {
STD_TORCH_CHECK(a_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
vllm::cutlass_scaled_mm_azp_sm90_int8(out, a, b, a_scales, b_scales, azp_adj,
azp, bias);
}
#endif

View File

@@ -1,451 +0,0 @@
#include <cudaTypedefs.h>
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "cutlass_extensions/common.hpp"
void cutlass_scaled_mm_sm75(torch::stable::Tensor& c,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_scaled_mm_sm80(torch::stable::Tensor& c,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_scaled_mm_sm89(torch::stable::Tensor& c,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
void cutlass_scaled_mm_sm90(torch::stable::Tensor& c,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
#endif
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
void cutlass_moe_mm_sm90(torch::stable::Tensor& out_tensors,
torch::stable::Tensor const& a_tensors,
torch::stable::Tensor const& b_tensors,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& expert_offsets,
torch::stable::Tensor const& problem_sizes,
torch::stable::Tensor const& a_strides,
torch::stable::Tensor const& b_strides,
torch::stable::Tensor const& c_strides,
bool per_act_token, bool per_out_ch);
#endif
#if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100
void cutlass_moe_mm_sm100(torch::stable::Tensor& out_tensors,
torch::stable::Tensor const& a_tensors,
torch::stable::Tensor const& b_tensors,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& expert_offsets,
torch::stable::Tensor const& problem_sizes,
torch::stable::Tensor const& a_strides,
torch::stable::Tensor const& b_strides,
torch::stable::Tensor const& c_strides,
bool per_act_token, bool per_out_ch);
#endif
#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
void cutlass_scaled_mm_sm120(torch::stable::Tensor& c,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
#endif
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
void cutlass_scaled_mm_sm100(torch::stable::Tensor& c,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
#endif
#if (defined(ENABLE_CUTLASS_MOE_SM90) && ENABLE_CUTLASS_MOE_SM90) || \
(defined(ENABLE_CUTLASS_MOE_SM100) && ENABLE_CUTLASS_MOE_SM100) || \
(defined(ENABLE_CUTLASS_MOE_SM120) && ENABLE_CUTLASS_MOE_SM120)
void get_cutlass_moe_mm_data_caller(
const torch::stable::Tensor& topk_ids,
torch::stable::Tensor& expert_offsets,
torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2,
torch::stable::Tensor& input_permutation,
torch::stable::Tensor& output_permutation, const int64_t num_experts,
const int64_t n, const int64_t k,
const std::optional<torch::stable::Tensor>& blockscale_offsets,
const bool is_gated);
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
const torch::stable::Tensor& expert_first_token_offset,
torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2, const int64_t n, const int64_t k,
const bool swap_ab);
void get_cutlass_batched_moe_mm_data_caller(
torch::stable::Tensor& expert_offsets,
torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2,
const torch::stable::Tensor& expert_num_tokens,
const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
const int64_t k);
#endif
void cutlass_scaled_mm_azp_sm75(
torch::stable::Tensor& c, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_scaled_mm_azp_sm80(
torch::stable::Tensor& c, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_scaled_mm_azp_sm89(
torch::stable::Tensor& c, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias);
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
void cutlass_scaled_mm_azp_sm90(
torch::stable::Tensor& c, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias);
#endif
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
// CUTLASS FP8 kernels need at least
// CUDA 12.0 on SM90 systems (Hopper)
// CUDA 12.4 on SM89 systems (Lovelace)
#if defined CUDA_VERSION
if (cuda_device_capability >= 90) {
return CUDA_VERSION >= 12000;
} else if (cuda_device_capability >= 89) {
return CUDA_VERSION >= 12040;
}
#endif
return false;
}
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
// CUTLASS block-quantized FP8 kernels need at least CUDA 12.0
// and at least SM90 (Hopper)
#if defined CUDA_VERSION
if (cuda_device_capability >= 100) {
return CUDA_VERSION >= 12080;
} else if (cuda_device_capability >= 90) {
return CUDA_VERSION >= 12000;
}
#endif
return false;
}
bool cutlass_group_gemm_supported(int64_t cuda_device_capability) {
// CUTLASS grouped FP8 kernels need at least CUDA 12.3 and SM90 (Hopper)
// or CUDA 12.8 and SM100 (Blackwell)
#if defined CUDA_VERSION
if (cuda_device_capability >= 100) {
return CUDA_VERSION >= 12080;
}
if (cuda_device_capability >= 90) {
return CUDA_VERSION >= 12030;
}
#endif
return false;
}
void cutlass_scaled_mm(torch::stable::Tensor& c, torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias) {
// Checks for conformality
STD_TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
STD_TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
b.size(1) == c.size(1));
// Check for strides and alignment
STD_TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major
STD_TORCH_CHECK(b.stride(0) == 1); // Column-major
STD_TORCH_CHECK(c.stride(0) % 16 == 0 &&
b.stride(1) % 16 == 0); // 16 Byte Alignment
if (bias) {
STD_TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
bias->dim() == 1);
}
const torch::stable::accelerator::DeviceGuard device_guard(
a.get_device_index());
int32_t version_num = get_sm_version_num();
#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
if (version_num >= 120) {
cutlass_scaled_mm_sm120(c, a, b, a_scales, b_scales, bias);
return;
}
#endif
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
if (version_num >= 100 && version_num < 120) {
cutlass_scaled_mm_sm100(c, a, b, a_scales, b_scales, bias);
return;
}
#endif
// Guard against compilation issues for sm90 kernels
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
if (version_num >= 90 && version_num < 100) {
// Hopper
cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
return;
}
#endif
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
if (version_num == 89) {
// Ada Lovelace
cutlass_scaled_mm_sm89(c, a, b, a_scales, b_scales, bias);
return;
}
if (version_num >= 80) {
// Ampere
cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias);
return;
}
if (version_num >= 75) {
// Turing
cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias);
return;
}
#endif
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_scaled_mm for a compute capability less than "
"CUDA device capability: ",
version_num);
}
void cutlass_moe_mm(torch::stable::Tensor& out_tensors,
torch::stable::Tensor const& a_tensors,
torch::stable::Tensor const& b_tensors,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& expert_offsets,
torch::stable::Tensor const& problem_sizes,
torch::stable::Tensor const& a_strides,
torch::stable::Tensor const& b_strides,
torch::stable::Tensor const& c_strides, bool per_act_token,
bool per_out_ch) {
int32_t version_num = get_sm_version_num();
#if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100
if (version_num >= 100 && version_num < 110) {
cutlass_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
expert_offsets, problem_sizes, a_strides, b_strides,
c_strides, per_act_token, per_out_ch);
return;
}
#endif
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
if (version_num >= 90 && version_num < 100) {
cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
expert_offsets, problem_sizes, a_strides, b_strides,
c_strides, per_act_token, per_out_ch);
return;
}
#endif
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_scaled_mm for CUDA device capability: ", version_num,
". Required capability: 90 or 100");
}
void get_cutlass_moe_mm_data(
const torch::stable::Tensor& topk_ids,
torch::stable::Tensor& expert_offsets,
torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2,
torch::stable::Tensor& input_permutation,
torch::stable::Tensor& output_permutation, const int64_t num_experts,
const int64_t n, const int64_t k,
const std::optional<torch::stable::Tensor>& blockscale_offsets,
const bool is_gated) {
// This function currently gets compiled only if we have a valid cutlass moe
// mm to run it for.
int32_t version_num = get_sm_version_num();
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
(defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1,
problem_sizes2, input_permutation,
output_permutation, num_experts, n, k,
blockscale_offsets, is_gated);
return;
#endif
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled get_cutlass_moe_mm_data: no cutlass_scaled_mm kernel for "
"CUDA device capability: ",
version_num, ". Required capability: 90, 100, or 120");
}
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
const torch::stable::Tensor& expert_first_token_offset,
torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2, const int64_t n, const int64_t k,
const bool swap_ab) {
int32_t version_num = get_sm_version_num();
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
(defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
expert_first_token_offset, problem_sizes1, problem_sizes2, n, k, swap_ab);
return;
#endif
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled get_cutlass_moe_mm_problem_sizes_from_expert_offsets: "
"no cutlass_scaled_mm kernel for CUDA device capability: ",
version_num, ". Required capability: 90, 100, or 120");
}
void get_cutlass_batched_moe_mm_data(
torch::stable::Tensor& expert_offsets,
torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2,
const torch::stable::Tensor& expert_num_tokens,
const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
const int64_t k) {
// This function currently gets compiled only if we have a valid cutlass moe
// mm to run it for.
int32_t version_num = get_sm_version_num();
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
(defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
get_cutlass_batched_moe_mm_data_caller(expert_offsets, problem_sizes1,
problem_sizes2, expert_num_tokens,
num_local_experts, padded_m, n, k);
return;
#endif
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled get_cutlass_batched_moe_mm_data: no "
"cutlass_scaled_mm kernel "
"for CUDA device capability: ",
version_num, ". Required capability: 90, 100, or 120");
}
void cutlass_scaled_mm_azp(torch::stable::Tensor& c,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias) {
// Checks for conformality
STD_TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
STD_TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
b.size(1) == c.size(1));
STD_TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
STD_TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));
// Check for strides and alignment
STD_TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major
STD_TORCH_CHECK(b.stride(0) == 1); // Column-major
STD_TORCH_CHECK(c.stride(0) % 16 == 0 &&
b.stride(1) % 16 == 0); // 16 Byte Alignment
STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
// bias, azp, azp_adj are all 1d
// bias and azp_adj have n elements, azp has m elements
if (bias) {
STD_TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous());
}
if (azp) {
STD_TORCH_CHECK(azp->numel() == a.size(0) && azp->is_contiguous());
}
STD_TORCH_CHECK(azp_adj.numel() == b.size(1) && azp_adj.is_contiguous());
// azp & bias types
STD_TORCH_CHECK(azp_adj.scalar_type() == torch::headeronly::ScalarType::Int);
STD_TORCH_CHECK(!azp ||
azp->scalar_type() == torch::headeronly::ScalarType::Int);
STD_TORCH_CHECK(!bias || bias->scalar_type() == c.scalar_type(),
"currently bias dtype must match output dtype ",
c.scalar_type());
const torch::stable::accelerator::DeviceGuard device_guard(
a.get_device_index());
int32_t version_num = get_sm_version_num();
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
if (version_num >= 90) {
cutlass_scaled_mm_azp_sm90(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
return;
}
#endif
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
if (version_num == 89) {
// Ada Lovelace
cutlass_scaled_mm_azp_sm89(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
return;
}
if (version_num >= 80) {
// Ampere
cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
return;
}
// Turing
STD_TORCH_CHECK(version_num >= 75);
cutlass_scaled_mm_azp_sm75(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
return;
#endif
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_scaled_mm_azp for a compute capability less than "
"CUDA device capability: ",
version_num);
}
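A worked example of tensor layouts that satisfy the conformality and alignment checks in cutlass_scaled_mm above; the shapes are arbitrary and chosen only for illustration.
// M = 512, K = 4096, N = 11008, int8 inputs, bf16 output
// a: row-major [512, 4096]              -> a.stride(1) == 1
// b: column-major [4096, 11008], i.e. the transpose of a row-major [11008, 4096]
//    weight                             -> b.stride(0) == 1, b.stride(1) == 4096  (4096 % 16 == 0)
// c: row-major [512, 11008]             -> c.stride(1) == 1, c.stride(0) == 11008 (11008 % 16 == 0)
// a_scales / b_scales: float32, either a single element (per-tensor) or one per
// row / output channel, matching the quantization modes named in the op schema.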

View File

@@ -31,174 +31,6 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
"per_token_group_quant_int8(Tensor input, Tensor! output_q, Tensor! " "per_token_group_quant_int8(Tensor input, Tensor! output_q, Tensor! "
"output_s, int group_size, float eps, float int8_min, float int8_max) -> " "output_s, int group_size, float eps, float int8_min, float int8_max) -> "
"()"); "()");
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
// quantization, as well as bias
ops.def(
"cutlass_scaled_mm(Tensor! out, Tensor a,"
" Tensor b, Tensor a_scales,"
" Tensor b_scales, Tensor? bias) -> ()");
// CUTLASS w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
// quantization.
ops.def(
"cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
" Tensor b, Tensor a_scales,"
" Tensor b_scales, Tensor azp_adj,"
" Tensor? azp, Tensor? bias) -> ()");
// Check if cutlass scaled_mm is supported for CUDA devices of the given
// capability
ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool");
// Check if cutlass grouped gemm is supported for CUDA devices of the given
// capability
ops.def("cutlass_group_gemm_supported(int cuda_device_capability) -> bool");
// CUTLASS w8a8 grouped GEMM
ops.def(
"cutlass_moe_mm(Tensor! out_tensors, Tensor a_tensors, Tensor b_tensors, "
" Tensor a_scales, Tensor b_scales, Tensor expert_offsets, "
" Tensor problem_sizes, Tensor a_strides, "
" Tensor b_strides, Tensor c_strides, bool per_act_token, "
" bool per_out_ch) -> ()");
// A function that computes data required to run fused MoE with w8a8 grouped
// GEMM. It takes topk_ids as an input, and computes expert_offsets
// (token start indices of each expert). In addition to this, it computes
// problem sizes for each expert's multiplication used by the two mms called
// from fused MoE operation, and arrays with permutations required to shuffle
// and de-shuffle the input/output of the fused operation.
ops.def(
"get_cutlass_moe_mm_data(Tensor topk_ids, Tensor! expert_offsets, "
" Tensor! problem_sizes1, Tensor! problem_sizes2, "
" Tensor! input_permutation, "
" Tensor! output_permutation, int num_experts, "
" int n, int k, Tensor? blockscale_offsets, "
" bool is_gated) -> ()");
// compute per-expert problem sizes from expert_first_token_offset
// produced by vLLM's moe_permute kernel
ops.def(
"get_cutlass_moe_mm_problem_sizes_from_expert_offsets("
" Tensor expert_first_token_offset, "
" Tensor! problem_sizes1, "
" Tensor! problem_sizes2, "
" int n, int k, bool swap_ab) -> ()");
// A function that computes data required to run fused MoE with w8a8 grouped
// GEMM in batched expert format. It takes expert_num_tokens
// as an input, and computes expert_offsets (token start indices of each
// expert). In addition to this, it computes problem sizes for each expert's
// multiplication used by the two mms called from fused MoE operation.
ops.def(
"get_cutlass_batched_moe_mm_data(Tensor! expert_offsets, "
" Tensor! problem_sizes1, "
" Tensor! problem_sizes2, "
" Tensor expert_num_tokens, "
" int num_local_experts, int padded_m, "
" int n, int k) -> ()");
// Check if cutlass scaled_mm supports block quantization (used by DeepSeekV3)
ops.def(
"cutlass_scaled_mm_supports_block_fp8(int cuda_device_capability) -> "
"bool");
// CUTLASS nvfp4 block scaled GEMM
ops.def(
"cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
" Tensor block_scale_a, Tensor block_scale_b,"
" Tensor alpha) -> ()");
// cutlass nvfp4 block scaled group GEMM
ops.def(
"cutlass_fp4_group_mm(Tensor! out, Tensor a, Tensor b,"
" Tensor a_blockscale, Tensor b_blockscales, Tensor alphas,"
" Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()");
// Compute NVFP4 block quantized tensor.
ops.def(
"scaled_fp4_quant(Tensor input,"
" Tensor input_scale, bool "
"is_sf_swizzled_layout) -> (Tensor, Tensor)");
// Out variant
// TODO: Add out_variant tag once PyTorch supports it (added in 2.11)
// This registration is now migrated to stable ABI
// at::Tag::out_variant is not available in the stable ABI (enum_tag.h is not
// yet in torch/headeronly), the tag should be applied from Python
// via torch.library.Library.define(..., tags=(torch.Tag.out_variant,))
// with the .impl remaining in C++.
// See pytorch/pytorch#176117.
ops.def(
"scaled_fp4_quant.out(Tensor input,"
" Tensor input_scale, bool "
"is_sf_swizzled_layout, *, Tensor(a!) output, Tensor(b!) output_scale) "
"-> ()");
// Compute NVFP4 experts quantization.
ops.def(
"scaled_fp4_experts_quant(Tensor! output, Tensor! output_scale,"
"Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
"Tensor output_scale_offset_by_experts) -> ()");
// Fused SiLU+Mul+NVFP4 experts quantization.
ops.def(
"silu_and_mul_scaled_fp4_experts_quant(Tensor! output, Tensor! "
"output_scale,"
"Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
"Tensor output_scale_offset_by_experts) -> ()");
// Fused SiLU+Mul+NVFP4 quantization.
ops.def(
"silu_and_mul_nvfp4_quant(Tensor! result, Tensor! result_block_scale, "
"Tensor input, Tensor input_global_scale) -> ()");
// Check if cutlass_scaled_mm_fp4 is supported for CUDA devices
// of the given capability
ops.def("cutlass_scaled_mm_supports_fp4(int cuda_device_capability) -> bool");
// CUTLASS w4a8 GEMM
ops.def(
"cutlass_w4a8_mm("
" Tensor A,"
" Tensor B,"
" Tensor group_scales,"
" int group_size,"
" Tensor channel_scales,"
" Tensor token_scales,"
" ScalarType? out_type,"
" str? maybe_schedule"
") -> Tensor");
// pack scales
ops.def("cutlass_pack_scale_fp8(Tensor scales) -> Tensor");
// encode and reorder weight matrix
ops.def("cutlass_encode_and_reorder_int4b(Tensor B) -> Tensor");
// CUTLASS w4a8 grouped GEMM
ops.def(
"cutlass_w4a8_moe_mm("
" Tensor! out_tensors,"
" Tensor a_tensors,"
" Tensor b_tensors,"
" Tensor a_scales,"
" Tensor b_scales,"
" Tensor b_group_scales,"
" int b_group_size,"
" Tensor expert_offsets,"
" Tensor problem_sizes,"
" Tensor a_strides,"
" Tensor b_strides,"
" Tensor c_strides,"
" Tensor group_scale_strides,"
" str? maybe_schedule"
") -> ()");
ops.def(
"cutlass_encode_and_reorder_int4b_grouped(Tensor b_tensors) -> (Tensor, "
"Tensor)");
#endif
}
@@ -214,45 +46,6 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
TORCH_BOX(&per_token_group_quant_8bit_packed)); TORCH_BOX(&per_token_group_quant_8bit_packed));
ops.impl("per_token_group_quant_int8", ops.impl("per_token_group_quant_int8",
TORCH_BOX(&per_token_group_quant_int8)); TORCH_BOX(&per_token_group_quant_int8));
// CUTLASS scaled_mm ops
ops.impl("cutlass_scaled_mm", TORCH_BOX(&cutlass_scaled_mm));
ops.impl("cutlass_scaled_mm_azp", TORCH_BOX(&cutlass_scaled_mm_azp));
ops.impl("cutlass_moe_mm", TORCH_BOX(&cutlass_moe_mm));
ops.impl("get_cutlass_moe_mm_data", TORCH_BOX(&get_cutlass_moe_mm_data));
ops.impl("get_cutlass_moe_mm_problem_sizes_from_expert_offsets",
TORCH_BOX(&get_cutlass_moe_mm_problem_sizes_from_expert_offsets));
ops.impl("get_cutlass_batched_moe_mm_data",
TORCH_BOX(&get_cutlass_batched_moe_mm_data));
// FP4/NVFP4 ops
ops.impl("cutlass_scaled_fp4_mm", TORCH_BOX(&cutlass_scaled_fp4_mm));
ops.impl("scaled_fp4_quant", TORCH_BOX(&scaled_fp4_quant_func));
ops.impl("scaled_fp4_quant.out", TORCH_BOX(&scaled_fp4_quant_out));
ops.impl("scaled_fp4_experts_quant", TORCH_BOX(&scaled_fp4_experts_quant));
ops.impl("silu_and_mul_scaled_fp4_experts_quant",
TORCH_BOX(&silu_and_mul_scaled_fp4_experts_quant));
ops.impl("silu_and_mul_nvfp4_quant", TORCH_BOX(&silu_and_mul_nvfp4_quant));
// W4A8 ops: impl registrations are in the source files
// (w4a8_mm_entry.cu and w4a8_grouped_mm_entry.cu)
#endif
}
// These capability-check functions take only primitive args (no tensors), so
// there is no device to dispatch on. CompositeExplicitAutograd makes them
// available for all backends. This is the stable ABI equivalent of calling
// ops.impl("op_name", &func) without a dispatch key in the non-stable API.
STABLE_TORCH_LIBRARY_IMPL(_C, CompositeExplicitAutograd, ops) {
#ifndef USE_ROCM
ops.impl("cutlass_scaled_mm_supports_fp8",
TORCH_BOX(&cutlass_scaled_mm_supports_fp8));
ops.impl("cutlass_group_gemm_supported",
TORCH_BOX(&cutlass_group_gemm_supported));
ops.impl("cutlass_scaled_mm_supports_block_fp8",
TORCH_BOX(&cutlass_scaled_mm_supports_block_fp8));
ops.impl("cutlass_scaled_mm_supports_fp4",
TORCH_BOX(&cutlass_scaled_mm_supports_fp4));
#endif
}

View File

@@ -1,17 +1,10 @@
#pragma once
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/stable/accelerator.h>
#include <torch/csrc/stable/ops.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/util/shim_utils.h>
#include <cuda_runtime.h>
// Stable ABI equivalent of TORCH_CHECK_NOT_IMPLEMENTED.
#define STD_TORCH_CHECK_NOT_IMPLEMENTED(cond, ...) \
STD_TORCH_CHECK(cond, "NotImplementedError: ", __VA_ARGS__)
// Utility to get the current CUDA stream for a given device using stable APIs.
// Returns a cudaStream_t for use in kernel launches.
inline cudaStream_t get_current_cuda_stream(int32_t device_index = -1) {
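The hunk above cuts off the body of get_current_cuda_stream. As a minimal sketch (not the code from this repository), such a helper can be built on the AOTI shim header that remains included, assuming the aoti_torch_get_current_cuda_stream entry point and the AOTI_TORCH_SUCCESS status code from shim.h.
// Sketch only, under the assumptions stated above.
inline cudaStream_t get_current_cuda_stream_sketch(int32_t device_index) {
  void* stream = nullptr;
  // The shim call writes the current ATen CUDA stream for `device_index` into `stream`.
  if (aoti_torch_get_current_cuda_stream(device_index, &stream) !=
      AOTI_TORCH_SUCCESS) {
    return nullptr;  // illustrative error handling only
  }
  return reinterpret_cast<cudaStream_t>(stream);
}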

View File

@@ -21,7 +21,7 @@ struct SSMParamsBase {
int dim_ngroups_ratio;
bool is_variable_B;
bool is_variable_C;
-int64_t null_block_id;
+int64_t pad_slot_id;
bool delta_softplus;
bool cache_enabled;

View File

@@ -118,17 +118,9 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
const int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr
: reinterpret_cast<int *>(params.cache_indices_ptr);
-int cache_index;
-if (cache_indices == nullptr) {
-cache_index = batch_id;
-} else if (params.cache_enabled) {
-const int* initial_state_idx = reinterpret_cast<const int*>(params.initial_state_idx_ptr);
-cache_index = cache_indices[batch_id * params.cache_indices_stride + initial_state_idx[batch_id]];
-} else {
-cache_index = cache_indices[batch_id];
-}
-// Skip batch entries whose cache index maps to the null block (padding).
-if (cache_indices != nullptr && cache_index == params.null_block_id){
+const int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];
+// cache_index == params.pad_slot_id is defined as padding, so we exit early
+if (cache_index == params.pad_slot_id){
return;
}
input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + sequence_start_index * params.u_batch_stride
@@ -535,7 +527,7 @@ void set_ssm_params_fwd(SSMParamsBase &params,
const std::optional<at::Tensor>& cache_indices,
const std::optional<at::Tensor>& has_initial_state,
bool varlen,
-int64_t null_block_id,
+int64_t pad_slot_id,
int64_t block_size,
const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
@@ -552,7 +544,7 @@ void set_ssm_params_fwd(SSMParamsBase &params,
params.dstate = dstate;
params.n_groups = n_groups;
params.dim_ngroups_ratio = dim / n_groups;
-params.null_block_id = null_block_id;
+params.pad_slot_id = pad_slot_id;
params.delta_softplus = delta_softplus;
@@ -666,7 +658,7 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
const torch::Tensor &ssm_states,
// used to identify padding entries if cache_indices provided
// in case of padding, the kernel will return early
-int64_t null_block_id,
+int64_t pad_slot_id,
int64_t block_size,
const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
@@ -813,7 +805,7 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
cache_indices,
has_initial_state,
varlen,
-null_block_id,
+pad_slot_id,
block_size,
block_idx_first_scheduled_token,
block_idx_last_scheduled_token,
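In effect, the rewritten early exit above makes padded batch entries a no-op; a small worked example with a hypothetical PAD_SLOT_ID sentinel passed in as pad_slot_id:
// cache_indices = {5, PAD_SLOT_ID, 2} for a batch of 3 sequences
//   batch_id 0 -> cache_index 5           -> scan runs, state written to slot 5
//   batch_id 1 -> cache_index PAD_SLOT_ID -> kernel returns immediately, no state write
//   batch_id 2 -> cache_index 2           -> scan runs, state written to slot 2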

View File

@@ -0,0 +1,144 @@
/*
* Adapted from
* https://github.com/NVIDIA/TensorRT-LLM/blob/v1.3.0rc7/cpp/tensorrt_llm/kernels/tinygemm2/tinygemm2_cuda.cu
* Copyright (c) 2025, The vLLM team.
 * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/all.h>
#include "gpt_oss_router_gemm.cuh"
void launch_gpt_oss_router_gemm(__nv_bfloat16* gA, __nv_bfloat16* gB,
__nv_bfloat16* gC, __nv_bfloat16* bias,
int batch_size, int output_features,
int input_features, cudaStream_t stream) {
static int const WARP_TILE_M = 16;
static int const TILE_M = WARP_TILE_M;
static int const TILE_N = 8;
static int const TILE_K = 64;
static int const STAGES = 16;
static int const STAGE_UNROLL = 4;
static bool const PROFILE = false;
CUtensorMap weight_map{};
CUtensorMap activation_map{};
constexpr uint32_t rank = 2;
uint64_t size[rank] = {(uint64_t)input_features, (uint64_t)output_features};
uint64_t stride[rank - 1] = {input_features * sizeof(__nv_bfloat16)};
uint32_t box_size[rank] = {TILE_K, TILE_M};
uint32_t elem_stride[rank] = {1, 1};
CUresult res = cuTensorMapEncodeTiled(
&weight_map, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, rank,
gB, size, stride, box_size, elem_stride,
CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B,
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE,
CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
TORCH_CHECK(res == CUDA_SUCCESS,
"cuTensorMapEncodeTiled failed for weight_map, error code=",
static_cast<int>(res));
size[1] = batch_size;
box_size[1] = TILE_N;
res = cuTensorMapEncodeTiled(
&activation_map, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
rank, gA, size, stride, box_size, elem_stride,
CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B,
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE,
CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
TORCH_CHECK(res == CUDA_SUCCESS,
"cuTensorMapEncodeTiled failed for activation_map, error code=",
static_cast<int>(res));
int smem_size = STAGES * STAGE_UNROLL *
(TILE_M * TILE_K * sizeof(__nv_bfloat16) +
TILE_N * TILE_K * sizeof(__nv_bfloat16));
gpuErrChk(cudaFuncSetAttribute(
gpt_oss_router_gemm_kernel<WARP_TILE_M, TILE_M, TILE_N, TILE_K, STAGES,
STAGE_UNROLL, PROFILE>,
cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
int tiles_m = (output_features + TILE_M - 1) / TILE_M;
int tiles_n = (batch_size + TILE_N - 1) / TILE_N;
dim3 grid(tiles_m, tiles_n);
dim3 block(384);
cudaLaunchConfig_t config;
cudaLaunchAttribute attrs[1];
config.gridDim = grid;
config.blockDim = block;
config.dynamicSmemBytes = smem_size;
config.stream = stream;
config.attrs = attrs;
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = 1;
config.numAttrs = 1;
cudaLaunchKernelEx(
&config,
&gpt_oss_router_gemm_kernel<WARP_TILE_M, TILE_M, TILE_N, TILE_K, STAGES,
STAGE_UNROLL, PROFILE>,
gC, gA, gB, bias, output_features, batch_size, input_features, weight_map,
activation_map, nullptr);
}
void gpt_oss_router_gemm_cuda_forward(torch::Tensor& output,
torch::Tensor input, torch::Tensor weight,
torch::Tensor bias) {
auto const batch_size = input.size(0);
auto const input_dim = input.size(1);
auto const output_dim = weight.size(0);
auto stream = at::cuda::getCurrentCUDAStream();
if (input.scalar_type() == at::ScalarType::BFloat16) {
launch_gpt_oss_router_gemm((__nv_bfloat16*)input.data_ptr(),
(__nv_bfloat16*)weight.data_ptr(),
(__nv_bfloat16*)output.mutable_data_ptr(),
(__nv_bfloat16*)bias.data_ptr(), batch_size,
output_dim, input_dim, stream);
} else {
throw std::invalid_argument("Unsupported dtype, only supports bfloat16");
}
}
void gpt_oss_router_gemm(torch::Tensor& output, torch::Tensor input,
torch::Tensor weight, torch::Tensor bias) {
TORCH_CHECK(input.dim() == 2, "input must be 2D");
TORCH_CHECK(weight.dim() == 2, "weight must be 2D");
TORCH_CHECK(bias.dim() == 1, "bias must be 1D");
TORCH_CHECK(input.sizes()[1] == weight.sizes()[1],
"input.size(1) must match weight.size(1)");
TORCH_CHECK(weight.sizes()[0] == bias.sizes()[0],
"weight.size(0) must match bias.size(0)");
TORCH_CHECK(input.scalar_type() == at::ScalarType::BFloat16,
"input tensor must be bfloat16");
TORCH_CHECK(weight.scalar_type() == at::ScalarType::BFloat16,
"weight tensor must be bfloat16");
TORCH_CHECK(bias.scalar_type() == at::ScalarType::BFloat16,
"bias tensor must be bfloat16");
gpt_oss_router_gemm_cuda_forward(output, input, weight, bias);
}
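A hedged usage sketch for the entry point above (not part of the diff): shapes follow the TORCH_CHECKs in gpt_oss_router_gemm — input [T, K], weight [E, K], bias [E], all bfloat16 on CUDA — and the [T, E] output shape is inferred from the kernel's output indexing, so treat the wrapper below as an assumption rather than an existing API.

#include <torch/all.h>

// Hypothetical wrapper around the gpt_oss_router_gemm() defined above.
torch::Tensor route_logits(const torch::Tensor& hidden,  // [T, K] bf16, CUDA
                           const torch::Tensor& weight,  // [E, K] bf16, CUDA
                           const torch::Tensor& bias) {  // [E]    bf16, CUDA
  auto logits =
      torch::empty({hidden.size(0), weight.size(0)}, hidden.options());
  gpt_oss_router_gemm(logits, hidden, weight, bias);  // fills logits in place
  return logits;  // one row of router logits per token
}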

View File

@@ -0,0 +1,447 @@
/*
* Adapted from
* https://github.com/NVIDIA/TensorRT-LLM/blob/v1.3.0rc7/cpp/tensorrt_llm/kernels/tinygemm2/tinygemm2_kernel.cuh
* Copyright (c) 2025, The vLLM team.
 * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "cuda_bf16.h"
#include <stdint.h>
#include <stdio.h>
#include <vector>
#include "cuda_pipeline.h"
#include <cuda.h>
#include <cuda/barrier>
#include <cuda/std/utility>
#include <cuda_runtime.h>
using barrier = cuda::barrier<cuda::thread_scope_block>;
namespace cde = cuda::device::experimental;
namespace ptx = cuda::ptx;
#define gpuErrChk(ans) \
{ \
gpuAssert((ans), __FILE__, __LINE__); \
}
inline void gpuAssert(cudaError_t code, char const* file, int line,
bool abort = true) {
if (code != cudaSuccess) {
fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file,
line);
if (abort) {
throw std::runtime_error(cudaGetErrorString(code));
}
}
}
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
__device__ uint64_t gclock64() {
unsigned long long int rv;
asm volatile("mov.u64 %0, %%globaltimer;" : "=l"(rv));
return rv;
}
__device__ void ldmatrix(__nv_bfloat16 rv[2], uint32_t smem_ptr) {
int dst;
asm volatile("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n"
: "=r"(dst)
: "r"(smem_ptr));
int* rvi = reinterpret_cast<int*>(&rv[0]);
rvi[0] = dst;
}
__device__ void ldmatrix2(__nv_bfloat16 rv[4], uint32_t smem_ptr) {
int x, y;
asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n"
: "=r"(x), "=r"(y)
: "r"(smem_ptr));
int* rvi = reinterpret_cast<int*>(&rv[0]);
rvi[0] = x;
rvi[1] = y;
}
__device__ void ldmatrix4(__nv_bfloat16 rv[8], uint32_t smem_ptr) {
int x, y, z, w;
asm volatile(
"ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];"
: "=r"(x), "=r"(y), "=r"(z), "=r"(w)
: "r"(smem_ptr));
int* rvi = reinterpret_cast<int*>(&rv[0]);
rvi[0] = x;
rvi[1] = y;
rvi[2] = z;
rvi[3] = w;
}
__device__ void HMMA_1688(float d[4], __nv_bfloat16 a[4], __nv_bfloat16 b[2],
float c[4]) {
uint32_t const* A = reinterpret_cast<uint32_t const*>(&a[0]);
uint32_t const* B = reinterpret_cast<uint32_t const*>(&b[0]);
float const* C = reinterpret_cast<float const*>(&c[0]);
float* D = reinterpret_cast<float*>(&d[0]);
asm volatile(
"mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
: "r"(A[0]), "r"(A[1]), "r"(B[0]), "f"(C[0]), "f"(C[1]), "f"(C[2]),
"f"(C[3]));
}
__device__ void HMMA_16816(float d[4], __nv_bfloat16 a[8], __nv_bfloat16 b[4],
float c[4]) {
uint32_t const* A = reinterpret_cast<uint32_t const*>(&a[0]);
uint32_t const* B = reinterpret_cast<uint32_t const*>(&b[0]);
float const* C = reinterpret_cast<float const*>(&c[0]);
float* D = reinterpret_cast<float*>(&d[0]);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
: "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
"f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
}
__device__ void bar_wait(uint32_t bar_ptr, int phase) {
asm volatile(
"{\n"
".reg .pred P1;\n"
"LAB_WAIT:\n"
"mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1;\n"
"@P1 bra.uni DONE;\n"
"bra.uni LAB_WAIT;\n"
"DONE:\n"
"}\n" ::"r"(bar_ptr),
"r"(phase));
}
__device__ bool bar_try_wait(uint32_t bar_ptr, int phase) {
uint32_t success;
#ifdef INTERNAL
asm volatile(".pragma \"set knob DontInsertYield\";\n" : : : "memory");
#endif
asm volatile(
"{\n\t"
".reg .pred P1; \n\t"
"mbarrier.try_wait.parity.shared::cta.b64 P1, [%1], %2; \n\t"
"selp.b32 %0, 1, 0, P1; \n\t"
"}"
: "=r"(success)
: "r"(bar_ptr), "r"(phase));
return success;
}
__device__ uint32_t elect_one_sync() {
uint32_t pred = 0;
uint32_t laneid = 0;
asm volatile(
"{\n"
".reg .b32 %%rx;\n"
".reg .pred %%px;\n"
" elect.sync %%rx|%%px, %2;\n"
"@%%px mov.s32 %1, 1;\n"
" mov.s32 %0, %%rx;\n"
"}\n"
: "+r"(laneid), "+r"(pred)
: "r"(0xFFFFFFFF));
return pred;
}
#endif
struct Profile {
uint64_t start;
uint64_t weight_load_start;
uint64_t act_load_start;
uint64_t compute_start;
uint64_t complete;
};
template <int WARP_TILE_M, int TILE_M, int TILE_N, int TILE_K, int STAGES,
int STAGE_UNROLL, bool PROFILE>
__global__ __launch_bounds__(384, 1) void gpt_oss_router_gemm_kernel(
__nv_bfloat16* output, __nv_bfloat16* weights, __nv_bfloat16* activations,
__nv_bfloat16* bias, int M, int N, int K,
const __grid_constant__ CUtensorMap weight_map,
const __grid_constant__ CUtensorMap activation_map,
Profile* profile = nullptr) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
if (PROFILE && threadIdx.x == 0 && blockIdx.y == 0)
profile[blockIdx.x].start = gclock64();
extern __shared__ __align__(128) char smem[];
__nv_bfloat16* sh_weights = (__nv_bfloat16*)&smem[0];
__nv_bfloat16* sh_activations =
(__nv_bfloat16*)&smem[STAGES * STAGE_UNROLL * TILE_M * TILE_K *
sizeof(__nv_bfloat16)];
#pragma nv_diag_suppress static_var_with_dynamic_init
__shared__ barrier bar_wt_ready[STAGES];
__shared__ barrier bar_act_ready[STAGES];
__shared__ barrier bar_data_consumed[STAGES];
__shared__ float4 reduction_buffer[128];
__shared__ nv_bfloat16 sh_bias[TILE_M];
if (threadIdx.x == 0) {
for (int i = 0; i < STAGES; i++) {
init(&bar_wt_ready[i], 1);
init(&bar_act_ready[i], 1);
init(&bar_data_consumed[i], 32);
}
ptx::fence_proxy_async(ptx::space_shared);
asm volatile("prefetch.tensormap [%0];"
:
: "l"(reinterpret_cast<uint64_t>(&weight_map))
: "memory");
asm volatile("prefetch.tensormap [%0];"
:
: "l"(reinterpret_cast<uint64_t>(&activation_map))
: "memory");
}
__syncthreads();
int warp_id = threadIdx.x / 32;
int lane_id = threadIdx.x % 32;
int phase = 0;
int mib = blockIdx.x * TILE_M;
int ni = blockIdx.y * TILE_N;
float accum[4];
for (int i = 0; i < 4; i++) accum[i] = 0.f;
int const K_LOOPS_DMA =
(K + 4 * TILE_K * STAGE_UNROLL - 1) / (4 * (TILE_K * STAGE_UNROLL));
int const K_LOOPS_COMPUTE = K_LOOPS_DMA;
// Data loading thread
if (warp_id >= 4 && elect_one_sync()) {
int stage = warp_id % 4;
bool weight_warp = warp_id < 8;
if (!weight_warp) {
cudaGridDependencySynchronize();
cudaTriggerProgrammaticLaunchCompletion();
}
for (int ki = 0; ki < K_LOOPS_DMA; ki++) {
int k = (ki * 4 + (warp_id % 4)) * TILE_K * STAGE_UNROLL;
uint64_t desc_ptr_wt = reinterpret_cast<uint64_t>(&weight_map);
uint64_t desc_ptr_act = reinterpret_cast<uint64_t>(&activation_map);
uint32_t bar_ptr_wt = __cvta_generic_to_shared(&bar_wt_ready[stage]);
uint32_t bar_ptr_act = __cvta_generic_to_shared(&bar_act_ready[stage]);
int bytes_wt = TILE_M * TILE_K * sizeof(__nv_bfloat16);
int bytes_act = TILE_N * TILE_K * sizeof(__nv_bfloat16);
bar_wait(__cvta_generic_to_shared(&bar_data_consumed[stage]), phase ^ 1);
if (weight_warp)
asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;"
:
: "r"(bar_ptr_wt), "r"(STAGE_UNROLL * bytes_wt));
if (!weight_warp)
asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;"
:
: "r"(bar_ptr_act), "r"(STAGE_UNROLL * bytes_act));
if (PROFILE && blockIdx.y == 0 && ki == 0 && weight_warp)
profile[blockIdx.x].weight_load_start = gclock64();
if (PROFILE && blockIdx.y == 0 && ki == 0 && !weight_warp)
profile[blockIdx.x].act_load_start = gclock64();
for (int i = 0; i < STAGE_UNROLL; i++) {
uint32_t smem_ptr_wt = __cvta_generic_to_shared(
&sh_weights[(stage * STAGE_UNROLL + i) * TILE_M * TILE_K]);
uint32_t crd0 = k + i * TILE_K;
uint32_t crd1 = mib;
if (weight_warp)
asm volatile(
"cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::complete_"
"tx::bytes [%0], [%1, {%3,%4}], "
"[%2];"
:
: "r"(smem_ptr_wt), "l"(desc_ptr_wt), "r"(bar_ptr_wt), "r"(crd0),
"r"(crd1)
: "memory");
uint32_t smem_ptr_act = __cvta_generic_to_shared(
&sh_activations[(stage * STAGE_UNROLL + i) * TILE_N * TILE_K]);
crd0 = k + i * TILE_K;
crd1 = ni;
if (!weight_warp)
asm volatile(
"cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::complete_"
"tx::bytes [%0], [%1, {%3,%4}], "
"[%2];"
:
: "r"(smem_ptr_act), "l"(desc_ptr_act), "r"(bar_ptr_act),
"r"(crd0), "r"(crd1)
: "memory");
}
stage += 4;
if (stage >= STAGES) {
stage = warp_id % 4;
phase ^= 1;
}
}
// Wait for pending loads to be consumed before exiting, to avoid race
for (int i = 0; i < (STAGES / 4) - 1; i++) {
bar_wait(__cvta_generic_to_shared(&bar_data_consumed[stage]), phase ^ 1);
stage += 4;
if (stage >= STAGES) {
stage = warp_id % 4;
phase ^= 1;
}
}
}
// Compute threads
else if (warp_id < 4) {
// Sneak the bias load into the compute warps since they're just waiting for
// stuff anyway
if (threadIdx.x < TILE_M) sh_bias[threadIdx.x] = bias[mib + threadIdx.x];
int stage = warp_id;
int phase = 0;
int lane_id_div8 = lane_id / 8;
int lane_id_mod8 = lane_id % 8;
int lane_row_offset_wt = (lane_id_div8 % 2) ? 8 : 0;
int lane_col_offset_wt = (lane_id_div8 / 2) ? 1 : 0;
int row_wt = lane_id_mod8 + lane_row_offset_wt;
int row_act = lane_id_mod8;
int row_offset_wt = (reinterpret_cast<uintptr_t>(sh_weights) / 128) % 8;
int row_offset_act = row_offset_wt;
uint32_t bar_ptr_wt = __cvta_generic_to_shared(&bar_wt_ready[stage]);
uint32_t bar_ptr_act = __cvta_generic_to_shared(&bar_act_ready[stage]);
bool weight_ready = bar_try_wait(bar_ptr_wt, phase);
bool act_ready = bar_try_wait(bar_ptr_act, phase);
#pragma unroll 2
for (int ki = 0; ki < K_LOOPS_COMPUTE; ki++) {
int next_stage = stage + 4;
int next_phase = phase;
if (next_stage >= STAGES) {
next_stage = warp_id;
next_phase ^= 1;
}
while (!weight_ready || !act_ready) {
weight_ready = bar_try_wait(bar_ptr_wt, phase);
act_ready = bar_try_wait(bar_ptr_act, phase);
}
if (PROFILE && blockIdx.y == 0 && threadIdx.x == 0 && ki == 0)
profile[blockIdx.x].compute_start = gclock64();
if (ki + 1 < K_LOOPS_COMPUTE) {
weight_ready = bar_try_wait(
__cvta_generic_to_shared(&bar_wt_ready[next_stage]), next_phase);
act_ready = bar_try_wait(
__cvta_generic_to_shared(&bar_act_ready[next_stage]), next_phase);
}
#pragma unroll
for (int su = 0; su < STAGE_UNROLL; su++) {
__nv_bfloat16* ptr_weights =
&sh_weights[(stage * STAGE_UNROLL + su) * TILE_M * TILE_K];
__nv_bfloat16* ptr_act =
&sh_activations[(stage * STAGE_UNROLL + su) * TILE_N * TILE_K];
#pragma unroll
for (int kii = 0; kii < TILE_K / 16; kii++) {
__nv_bfloat16 a[8];
__nv_bfloat16 b[4];
int col = 2 * kii + lane_col_offset_wt;
int col_sw = ((row_wt + row_offset_wt) % 8) ^ col;
ldmatrix4(a, __cvta_generic_to_shared(
&ptr_weights[row_wt * TILE_K + col_sw * 8]));
col = 2 * kii + lane_id_div8;
col_sw = ((row_act + row_offset_act) % 8) ^ col;
ldmatrix2(b, __cvta_generic_to_shared(
&ptr_act[row_act * TILE_K + 8 * col_sw]));
HMMA_16816(accum, a, b, accum);
}
}
uint32_t bar_c = __cvta_generic_to_shared(&bar_data_consumed[stage]);
asm volatile("mbarrier.arrive.shared::cta.b64 _, [%0];" : : "r"(bar_c));
stage = next_stage;
phase = next_phase;
}
float4 accum4;
accum4.x = accum[0];
accum4.y = accum[1];
accum4.z = accum[2];
accum4.w = accum[3];
reduction_buffer[threadIdx.x] = accum4;
__syncthreads();
if (warp_id == 0) {
int mi = mib + warp_id * WARP_TILE_M;
int tm = mi + lane_id / 4;
int tn = ni + 2 * (lane_id % 4);
float4 accum1 = reduction_buffer[32 + threadIdx.x];
float4 accum2 = reduction_buffer[64 + threadIdx.x];
float4 accum3 = reduction_buffer[96 + threadIdx.x];
accum[0] = accum[0] + accum1.x + accum2.x + accum3.x;
accum[1] = accum[1] + accum1.y + accum2.y + accum3.y;
accum[2] = accum[2] + accum1.z + accum2.z + accum3.z;
accum[3] = accum[3] + accum1.w + accum2.w + accum3.w;
float bias_lo = __bfloat162float(sh_bias[tm - mib]);
float bias_hi = __bfloat162float(sh_bias[tm + 8 - mib]);
if (tn < N && tm < M)
output[tn * M + tm] = __float2bfloat16(accum[0] + bias_lo);
if (tn + 1 < N && tm < M)
output[(tn + 1) * M + tm] = __float2bfloat16(accum[1] + bias_lo);
if (tn < N && tm + 8 < M)
output[tn * M + tm + 8] = __float2bfloat16(accum[2] + bias_hi);
if (tn + 1 < N && tm + 8 < M)
output[(tn + 1) * M + tm + 8] = __float2bfloat16(accum[3] + bias_hi);
if (PROFILE && blockIdx.y == 0 && threadIdx.x == 0)
profile[blockIdx.x].complete = gclock64();
}
}
#endif // end if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
}
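For orientation (not part of the diff), the staging buffers declared above dominate the kernel's shared-memory footprint. A standalone sketch of the arithmetic, with the tile constants copied from the launcher shown earlier in this compare:

#include <cstdio>

int main() {
  // Constants mirror launch_gpt_oss_router_gemm() shown earlier.
  constexpr int TILE_M = 16, TILE_N = 8, TILE_K = 64;
  constexpr int STAGES = 16, STAGE_UNROLL = 4;
  constexpr int kBf16Bytes = 2;  // sizeof(__nv_bfloat16)
  constexpr int smem_bytes =
      STAGES * STAGE_UNROLL *
      (TILE_M * TILE_K * kBf16Bytes + TILE_N * TILE_K * kBf16Bytes);
  static_assert(smem_bytes == 196608, "expected 192 KiB of staging buffers");
  // 192 KiB is far above the 48 KiB default dynamic shared-memory limit, which
  // is why the launcher raises cudaFuncAttributeMaxDynamicSharedMemorySize
  // before cudaLaunchKernelEx.
  std::printf("dynamic smem per block: %d bytes\n", smem_bytes);
  return 0;
}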

View File

@@ -108,15 +108,6 @@ QUANT_CONFIGS = [
         "thread_m_blocks": THREAD_M_BLOCKS,
         "group_blocks": [2],
     },
-    # MXFP8
-    {
-        "a_type": ["kBFloat16"],
-        "b_type": "kFE4M3fn",
-        "s_type": "kFE8M0fnu",
-        "thread_configs": THREAD_CONFIGS,
-        "thread_m_blocks": THREAD_M_BLOCKS,
-        "group_blocks": [2],
-    },
     # AWQ-INT4 with INT8 activation
     {
         "a_type": ["kS8"],

View File

@@ -343,8 +343,6 @@ __global__ void Marlin(
   if constexpr (b_type == vllm::kFE2M1f) {
     static_assert(s_type == vllm::kFE4M3fn && group_blocks == 1 ||
                   s_type == vllm::kFE8M0fnu && group_blocks == 2);
-  } else if constexpr (b_type == vllm::kFE4M3fn && s_type == vllm::kFE8M0fnu) {
-    static_assert(group_blocks == 2);
   } else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
     static_assert(s_type == vllm::kBFloat16);
   } else if constexpr (std::is_same<scalar_t, half>::value) {
@@ -359,10 +357,9 @@
   constexpr bool is_int_type = b_type == vllm::kU4 || b_type == vllm::kU8 ||
                                b_type == vllm::kS4 || b_type == vllm::kS8 ||
                                b_type == vllm::kU4B8 || b_type == vllm::kU8B128;
-  constexpr bool is_8bit_scale = s_type.size_bits() == 8;
   // see comments of dequant.h for more details
   constexpr bool dequant_skip_flop =
-      is_a_8bit || (b_type == vllm::kFE4M3fn && !(s_type == vllm::kFE8M0fnu)) ||
+      is_a_8bit || b_type == vllm::kFE4M3fn ||
       b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn ||
       has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
       has_zp && !is_zp_float && !(b_type == vllm::kU8);
@@ -376,7 +373,7 @@
   const int group_size =
       (!has_act_order && group_blocks == -1) ? prob_k : prob_k / num_groups;
   const int scales_expert_stride =
-      prob_n * prob_k / group_size / (is_8bit_scale ? 16 : 8);
+      prob_n * prob_k / group_size / (b_type == vllm::kFE2M1f ? 16 : 8);
   const int zp_expert_stride =
       is_zp_float ? prob_n * prob_k / group_size / 8
                   : prob_n * prob_k / group_size / (pack_factor * 4);
@@ -695,8 +692,9 @@
   constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;

   // Scale sizes/strides without act_order
-  int s_gl_stride = prob_n / (is_8bit_scale ? 16 : 8);
-  constexpr int s_sh_stride = 16 * thread_n_blocks / (is_8bit_scale ? 16 : 8);
+  int s_gl_stride = prob_n / (b_type == vllm::kFE2M1f ? 16 : 8);
+  constexpr int s_sh_stride =
+      16 * thread_n_blocks / (b_type == vllm::kFE2M1f ? 16 : 8);
   constexpr int s_tb_groups =
       !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
           ? thread_k_blocks / group_blocks
@@ -1133,7 +1131,7 @@
         int4* sh_s_stage = sh_s + s_sh_stage * pipe;
-        if constexpr (!is_8bit_scale) {
+        if constexpr (b_type_id != vllm::kFE2M1f.id()) {
           reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
               sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
         } else {
@@ -1142,7 +1140,7 @@
               sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
         }
       } else if (group_blocks >= b_sh_wr_iters) {
-        if constexpr (!is_8bit_scale) {
+        if constexpr (b_type_id != vllm::kFE2M1f.id()) {
           reinterpret_cast<int4*>(&frag_s[1])[0] =
               reinterpret_cast<int4*>(&frag_s[0])[0];
         } else {
@@ -1343,7 +1341,7 @@
     }
   }
-  if constexpr (s_type == vllm::kFE4M3fn || s_type == vllm::kFE8M0fnu) {
+  if constexpr (b_type == vllm::kFE2M1f) {
     int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];
     int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];

View File

@@ -599,9 +599,6 @@ torch::Tensor moe_wna16_marlin_gemm(
                 "When b_type = float4_e2m1f, b_scale scalar type must be",
                 "float8_e4m3fn (for NVFP4) or float8_e8m0fnu (for MXFP4).");
     }
-  } else if (b_type_id == vllm::kFE4M3fn.id() &&
-             b_scales.scalar_type() == at::ScalarType::Float8_e8m0fnu) {
-    s_type_id = vllm::kFE8M0fnu.id();
   }

   vllm::ScalarType a_type = vllm::ScalarType::from_id(a_type_id);

View File

@@ -70,4 +70,8 @@ torch::Tensor router_gemm_bf16_fp32(torch::Tensor const& input,
 // Supports num_tokens in [1, 16], num_experts in {256, 384}, hidden_dim = 7168
 void dsv3_router_gemm(torch::Tensor& output, const torch::Tensor& mat_a,
                       const torch::Tensor& mat_b);
+
+// gpt-oss optimized router GEMM kernel for SM90+
+void gpt_oss_router_gemm(torch::Tensor& output, torch::Tensor input,
+                         torch::Tensor weight, torch::Tensor bias);
 #endif

View File

@@ -132,6 +132,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
   // DeepSeek V3 optimized router GEMM for SM90+
   m.def("dsv3_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
   // conditionally compiled so impl registration is in source file
+
+  // gpt-oss optimized router GEMM kernel for SM90+
+  m.def(
+      "gpt_oss_router_gemm(Tensor! output, Tensor input, Tensor weights, "
+      "Tensor bias) -> ()");
+  m.impl("gpt_oss_router_gemm", torch::kCUDA, &gpt_oss_router_gemm);
 #endif
 }

View File

@@ -53,12 +53,12 @@
     const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
     const int64_t blocksparse_head_sliding_step);

-void merge_attn_states(
-    torch::Tensor& output, std::optional<torch::Tensor> output_lse,
-    const torch::Tensor& prefix_output, const torch::Tensor& prefix_lse,
-    const torch::Tensor& suffix_output, const torch::Tensor& suffix_lse,
-    const std::optional<int64_t> prefill_tokens_with_context,
-    const std::optional<torch::Tensor>& output_scale = std::nullopt);
+void merge_attn_states(torch::Tensor& output,
+                       std::optional<torch::Tensor> output_lse,
+                       const torch::Tensor& prefix_output,
+                       const torch::Tensor& prefix_lse,
+                       const torch::Tensor& suffix_output,
+                       const torch::Tensor& suffix_lse);

 #ifndef USE_ROCM
 void convert_vertical_slash_indexes(
     torch::Tensor& block_count,  // [BATCH, N_HEADS, NUM_ROWS]
@@ -143,14 +143,6 @@ void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input,
                               std::optional<torch::Tensor> residual,
                               int64_t group_size, bool is_scale_transposed);

-#ifndef USE_ROCM
-void silu_and_mul_per_block_quant(torch::Tensor& out,
-                                  torch::Tensor const& input,
-                                  torch::Tensor& scales, int64_t group_size,
-                                  std::optional<torch::Tensor> scale_ub,
-                                  bool is_scale_transposed);
-#endif
-
 void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
                       std::optional<torch::Tensor> key, int64_t head_size,
                       torch::Tensor& cos_sin_cache, bool is_neox);
@@ -160,6 +152,12 @@ void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
 void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input,
                         torch::Tensor& scale);

+#ifndef USE_ROCM
+void silu_and_mul_nvfp4_quant(torch::Tensor& out,
+                              torch::Tensor& output_block_scale,
+                              torch::Tensor& input,
+                              torch::Tensor& input_global_scale);
+#endif
+
 void persistent_masked_m_silu_mul_quant(
     const at::Tensor& input,   // (E, T, 2*H)
     const at::Tensor& counts,  // (E)
@@ -227,6 +225,89 @@ torch::Tensor ggml_moe_a8_vec(torch::Tensor X, torch::Tensor W,
 int64_t ggml_moe_get_block_size(int64_t type);
#ifndef USE_ROCM
bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability);
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability);
bool cutlass_group_gemm_supported(int64_t cuda_device_capability);
void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A,
torch::Tensor const& B, torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha);
void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias);
void cutlass_moe_mm(
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch);
void cutlass_fp4_group_mm(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets);
void get_cutlass_moe_mm_data(
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
const int64_t num_experts, const int64_t n, const int64_t k,
const std::optional<torch::Tensor>& blockscale_offsets,
const bool is_gated);
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
const torch::Tensor& expert_first_token_offset,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
const int64_t n, const int64_t k, const bool swap_ab);
void get_cutlass_batched_moe_mm_data(torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
const torch::Tensor& expert_num_tokens,
const int64_t num_local_experts,
const int64_t padded_m, const int64_t n,
const int64_t k);
void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
std::optional<torch::Tensor> const& azp,
std::optional<torch::Tensor> const& bias);
std::tuple<torch::Tensor, torch::Tensor> scaled_fp4_quant_func(
torch::Tensor const& input, torch::Tensor const& input_scale,
bool is_sf_swizzled_layout);
void scaled_fp4_quant_out(torch::Tensor const& input,
torch::Tensor const& input_scale,
bool is_sf_swizzled_layout, torch::Tensor& output,
torch::Tensor& output_scale);
void scaled_fp4_experts_quant(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts);
void silu_and_mul_scaled_fp4_experts_quant(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts);
#endif
 void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
                               torch::Tensor const& scale,
                               std::optional<torch::Tensor> const& azp);
@@ -262,7 +343,7 @@ void selective_scan_fwd(
     const std::optional<torch::Tensor>& query_start_loc,
     const std::optional<torch::Tensor>& cache_indices,
     const std::optional<torch::Tensor>& has_initial_state,
-    const torch::Tensor& ssm_states, int64_t null_block_id, int64_t block_size,
+    const torch::Tensor& ssm_states, int64_t pad_slot_id, int64_t block_size,
     const std::optional<torch::Tensor>& block_idx_first_scheduled_token,
     const std::optional<torch::Tensor>& block_idx_last_scheduled_token,
     const std::optional<torch::Tensor>& initial_state_idx,

View File

@@ -2,9 +2,10 @@
 #pragma once

 #include <cuda.h>
-#include <torch/csrc/stable/tensor.h>
-#include "libtorch_stable/torch_utils.h"
+#include <torch/all.h>
+#include <c10/cuda/CUDAStream.h>
+#include "core/scalar_type.hpp"

 #include "cutlass/bfloat16.h"
 #include "cutlass/float8.h"
@@ -40,7 +41,7 @@ __global__ void get_group_gemm_starts(
 }

 #define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE)                  \
-  else if (out_tensors.scalar_type() == TENSOR_C_TYPE) {                 \
+  else if (out_tensors.dtype() == TENSOR_C_TYPE) {                       \
     get_group_gemm_starts<cutlass::float_e4m3_t, int32_t, C_TYPE, float, \
                           cutlass::Array<cutlass::float_e4m3_t, 8>>      \
         <<<1, num_experts, 0, stream>>>(                                 \
@@ -65,34 +66,23 @@
 namespace {
 void run_get_group_gemm_starts(
-    torch::stable::Tensor const& expert_offsets, torch::stable::Tensor& a_ptrs,
-    torch::stable::Tensor& b_ptrs, torch::stable::Tensor& out_ptrs,
-    torch::stable::Tensor& a_scales_ptrs, torch::stable::Tensor& b_scales_ptrs,
-    torch::stable::Tensor& b_group_scales_ptrs,
-    torch::stable::Tensor const& a_tensors,
-    torch::stable::Tensor const& b_tensors, torch::stable::Tensor& out_tensors,
-    torch::stable::Tensor const& a_scales,
-    torch::stable::Tensor const& b_scales,
-    torch::stable::Tensor const& b_group_scales, const int64_t b_group_size) {
-  STD_TORCH_CHECK(a_tensors.scalar_type() ==
-                  torch::headeronly::ScalarType::Float8_e4m3fn);
-  STD_TORCH_CHECK(
-      b_tensors.scalar_type() ==
-      torch::headeronly::ScalarType::Int);  // int4 8x packed into int32
-  STD_TORCH_CHECK(a_scales.scalar_type() ==
-                  torch::headeronly::ScalarType::Float);
-  STD_TORCH_CHECK(b_scales.scalar_type() ==
-                  torch::headeronly::ScalarType::Float);
-  STD_TORCH_CHECK(
-      b_group_scales.scalar_type() ==
-      torch::headeronly::ScalarType::Float8_e4m3fn);  // the underlying torch
-                                                      // type is e4m3
-  STD_TORCH_CHECK(
-      out_tensors.scalar_type() ==
-      torch::headeronly::ScalarType::BFloat16);  // only support bf16 for now
+    torch::Tensor const& expert_offsets, torch::Tensor& a_ptrs,
+    torch::Tensor& b_ptrs, torch::Tensor& out_ptrs,
+    torch::Tensor& a_scales_ptrs, torch::Tensor& b_scales_ptrs,
+    torch::Tensor& b_group_scales_ptrs, torch::Tensor const& a_tensors,
+    torch::Tensor const& b_tensors, torch::Tensor& out_tensors,
+    torch::Tensor const& a_scales, torch::Tensor const& b_scales,
+    torch::Tensor const& b_group_scales, const int64_t b_group_size) {
+  TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
+  TORCH_CHECK(b_tensors.dtype() == torch::kInt32);  // int4 8x packed into int32
+  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
+  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
+  TORCH_CHECK(b_group_scales.dtype() ==
+              torch::kFloat8_e4m3fn);  // the underlying torch type is e4m3
+  TORCH_CHECK(out_tensors.dtype() ==
+              torch::kBFloat16);  // only support bf16 for now
   // expect int64_t to avoid overflow during offset calculations
-  STD_TORCH_CHECK(expert_offsets.scalar_type() ==
-                  torch::headeronly::ScalarType::Long);
+  TORCH_CHECK(expert_offsets.dtype() == torch::kInt64);
   int num_experts = static_cast<int>(expert_offsets.size(0));
   // logical k, n
@@ -100,15 +90,14 @@
   int64_t k = a_tensors.size(1);
   int64_t scale_k = cutlass::ceil_div(k, b_group_size);
-  auto stream = get_current_cuda_stream(a_tensors.get_device_index());
+  auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
   if (false) {
   }
-  __CALL_GET_STARTS_KERNEL(torch::headeronly::ScalarType::BFloat16,
-                           cutlass::bfloat16_t)
-  __CALL_GET_STARTS_KERNEL(torch::headeronly::ScalarType::Half, half)
+  __CALL_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t)
+  __CALL_GET_STARTS_KERNEL(torch::kFloat16, half)
   else {
-    STD_TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
+    TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
   }
 }

View File

@@ -14,12 +14,13 @@
 #include "cutlass/util/mixed_dtype_utils.hpp"

 // vllm includes
-#include <torch/csrc/stable/library.h>
-#include <torch/csrc/stable/tensor.h>
-#include "libtorch_stable/torch_utils.h"
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <torch/all.h>
 #include "cutlass_extensions/torch_utils.hpp"
 #include "cutlass_extensions/common.hpp"
+#include "core/registration.h"
 #include "get_group_starts.cuh"
 #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
 #include "w4a8_utils.cuh"
@@ -167,40 +168,31 @@
   static_assert(sizeof(LayoutB_Reordered) % sizeof(int32_t) == 0,
                 "LayoutB_Reordered size must be divisible by 4 bytes");

-  static void grouped_mm(torch::stable::Tensor& out_tensors,
-                         const torch::stable::Tensor& a_tensors,
-                         const torch::stable::Tensor& b_tensors,
-                         const torch::stable::Tensor& a_scales,
-                         const torch::stable::Tensor& b_scales,
-                         const torch::stable::Tensor& b_group_scales,
-                         const int64_t b_group_size,
-                         const torch::stable::Tensor& expert_offsets,
-                         const torch::stable::Tensor& problem_sizes_torch,
-                         const torch::stable::Tensor& a_strides,
-                         const torch::stable::Tensor& b_strides,
-                         const torch::stable::Tensor& c_strides,
-                         const torch::stable::Tensor& group_scale_strides) {
+  static void grouped_mm(
+      torch::Tensor& out_tensors, const torch::Tensor& a_tensors,
+      const torch::Tensor& b_tensors, const torch::Tensor& a_scales,
+      const torch::Tensor& b_scales, const torch::Tensor& b_group_scales,
+      const int64_t b_group_size, const torch::Tensor& expert_offsets,
+      const torch::Tensor& problem_sizes_torch, const torch::Tensor& a_strides,
+      const torch::Tensor& b_strides, const torch::Tensor& c_strides,
+      const torch::Tensor& group_scale_strides) {
     auto device = a_tensors.device();
     auto device_id = device.index();
-    const torch::stable::accelerator::DeviceGuard device_guard(device_id);
-    auto stream = get_current_cuda_stream(device_id);
+    const at::cuda::OptionalCUDAGuard device_guard(device);
+    auto stream = at::cuda::getCurrentCUDAStream(device_id);

     int num_experts = static_cast<int>(expert_offsets.size(0));
     int n = static_cast<int>(b_tensors.size(1));
     int k = static_cast<int>(b_tensors.size(2)) * PackFactor;

-    torch::stable::Tensor a_ptrs = torch::stable::empty(
-        num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
-    torch::stable::Tensor b_ptrs = torch::stable::empty(
-        num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
-    torch::stable::Tensor out_ptrs = torch::stable::empty(
-        num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
-    torch::stable::Tensor a_scales_ptrs = torch::stable::empty(
-        num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
-    torch::stable::Tensor b_scales_ptrs = torch::stable::empty(
-        num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
-    torch::stable::Tensor b_group_scales_ptrs = torch::stable::empty(
-        num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
+    auto options_int =
+        torch::TensorOptions().dtype(torch::kInt64).device(device);
+    torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
+    torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
+    torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
+    torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
+    torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
+    torch::Tensor b_group_scales_ptrs = torch::empty(num_experts, options_int);

     // get the correct offsets to pass to gemm
     run_get_group_gemm_starts(expert_offsets, a_ptrs, b_ptrs, out_ptrs,
@@ -255,9 +247,9 @@
     // Allocate workspace
     size_t workspace_size = GemmShuffled::get_workspace_size(arguments);
-    torch::stable::Tensor workspace = torch::stable::empty(
-        workspace_size, torch::headeronly::ScalarType::Byte, std::nullopt,
-        device);
+    torch::Tensor workspace =
+        torch::empty(workspace_size,
+                     torch::TensorOptions().dtype(torch::kU8).device(device));

     // Run GEMM
     GemmShuffled gemm;
@@ -302,20 +294,14 @@
 using Kernel_128x256_2x1x1_Coop =
     W4A8GroupedGemmKernel<Shape<_128, _256>, Shape<_2, _1, _1>, Coop, CoopEpi>;

-void mm_dispatch(torch::stable::Tensor& out_tensors,
-                 const torch::stable::Tensor& a_tensors,
-                 const torch::stable::Tensor& b_tensors,
-                 const torch::stable::Tensor& a_scales,
-                 const torch::stable::Tensor& b_scales,
-                 const torch::stable::Tensor& b_group_scales,
-                 const int64_t b_group_size,
-                 const torch::stable::Tensor& expert_offsets,
-                 const torch::stable::Tensor& problem_sizes,
-                 const torch::stable::Tensor& a_strides,
-                 const torch::stable::Tensor& b_strides,
-                 const torch::stable::Tensor& c_strides,
-                 const torch::stable::Tensor& group_scale_strides,
-                 const std::string& schedule) {
+void mm_dispatch(
+    torch::Tensor& out_tensors, const torch::Tensor& a_tensors,
+    const torch::Tensor& b_tensors, const torch::Tensor& a_scales,
+    const torch::Tensor& b_scales, const torch::Tensor& b_group_scales,
+    const int64_t b_group_size, const torch::Tensor& expert_offsets,
+    const torch::Tensor& problem_sizes, const torch::Tensor& a_strides,
+    const torch::Tensor& b_strides, const torch::Tensor& c_strides,
+    const torch::Tensor& group_scale_strides, const std::string& schedule) {
   if (schedule == "Kernel_128x16_1x1x1_Coop") {
     Kernel_128x16_1x1x1_Coop::grouped_mm(
         out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales,
@@ -372,23 +358,18 @@
         b_group_size, expert_offsets, problem_sizes, a_strides, b_strides,
         c_strides, group_scale_strides);
   } else {
-    STD_TORCH_CHECK(false,
-                    "cutlass_w4a8_moe_mm: unknown schedule string: ", schedule);
+    TORCH_CHECK(false,
+                "cutlass_w4a8_moe_mm: unknown schedule string: ", schedule);
   }
 }

-void mm(torch::stable::Tensor& out_tensors,
-        const torch::stable::Tensor& a_tensors,
-        const torch::stable::Tensor& b_tensors,
-        const torch::stable::Tensor& a_scales,
-        const torch::stable::Tensor& b_scales,
-        const torch::stable::Tensor& b_group_scales, const int64_t b_group_size,
-        const torch::stable::Tensor& expert_offsets,
-        const torch::stable::Tensor& problem_sizes,
-        const torch::stable::Tensor& a_strides,
-        const torch::stable::Tensor& b_strides,
-        const torch::stable::Tensor& c_strides,
-        const torch::stable::Tensor& group_scale_strides,
+void mm(torch::Tensor& out_tensors, const torch::Tensor& a_tensors,
+        const torch::Tensor& b_tensors, const torch::Tensor& a_scales,
+        const torch::Tensor& b_scales, const torch::Tensor& b_group_scales,
+        const int64_t b_group_size, const torch::Tensor& expert_offsets,
+        const torch::Tensor& problem_sizes, const torch::Tensor& a_strides,
+        const torch::Tensor& b_strides, const torch::Tensor& c_strides,
+        const torch::Tensor& group_scale_strides,
         std::optional<std::string> maybe_schedule) {
   // user has specified a schedule
   if (maybe_schedule) {
@@ -425,27 +406,26 @@
         a_strides, b_strides, c_strides, group_scale_strides, schedule);
 }

-std::tuple<torch::stable::Tensor, torch::stable::Tensor>
-encode_and_reorder_int4b(torch::stable::Tensor const& b_tensors) {
-  STD_TORCH_CHECK(b_tensors.scalar_type() ==
-                  torch::headeronly::ScalarType::Int);
-  STD_TORCH_CHECK(b_tensors.dim() == 3);  // (experts, n, k)
-  STD_TORCH_CHECK(b_tensors.is_contiguous());
-  STD_TORCH_CHECK(b_tensors.is_cuda());
+std::tuple<torch::Tensor, torch::Tensor> encode_and_reorder_int4b(
+    torch::Tensor const& b_tensors) {
+  TORCH_CHECK(b_tensors.dtype() == torch::kInt32);
+  TORCH_CHECK(b_tensors.dim() == 3);  // (experts, n, k)
+  TORCH_CHECK(b_tensors.is_contiguous());
+  TORCH_CHECK(b_tensors.is_cuda());

   int n = static_cast<int>(b_tensors.size(1));
   int k = static_cast<int>(b_tensors.size(2)) * PackFactor;  // logical k
   // CUTLASS reorder_tensor requires k % 256 == 0 and n % 16 == 0.
   // These misalignments cause silent OOB unless run under Compute Sanitizer.
-  STD_TORCH_CHECK(k % 256 == 0, "logical k must be divisible by 256");
-  STD_TORCH_CHECK(n % 16 == 0, "n must be divisible by 16");
+  TORCH_CHECK(k % 256 == 0, "logical k must be divisible by 256");
+  TORCH_CHECK(n % 16 == 0, "n must be divisible by 16");

   // we will store the layout to an int32 tensor;
   // this is the number of elements we need per layout
   constexpr size_t layout_width = sizeof(LayoutB_Reordered) / sizeof(int32_t);

-  torch::stable::Tensor b_tensors_packed = torch::stable::empty_like(b_tensors);
+  torch::Tensor b_tensors_packed = torch::empty_like(b_tensors);

   int num_experts = static_cast<int>(b_tensors.size(0));
   auto b_ptr = static_cast<QuantType const*>(b_tensors.const_data_ptr());
@@ -455,7 +435,7 @@
   size_t num_int4_elems = 1ull * num_experts * n * k;
   bool ok = vllm::cutlass_w4a8_utils::unified_encode_int4b(b_ptr, b_packed_ptr,
                                                            num_int4_elems);
-  STD_TORCH_CHECK(ok, "unified_encode_int4b failed");
+  TORCH_CHECK(ok, "unified_encode_int4b failed");

   // construct the layout once; assumes each expert has the same layout
   using LayoutType = LayoutB_Reordered;
@@ -476,28 +456,28 @@
   }

   // save the packed layout to torch tensor so we can re-use it
-  torch::stable::Tensor layout_cpu = torch::stable::empty(
-      {num_experts, layout_width}, torch::headeronly::ScalarType::Int,
-      std::nullopt, torch::stable::Device(torch::stable::DeviceType::CPU));
-  int32_t* layout_data = layout_cpu.mutable_data_ptr<int32_t>();
+  auto cpu_opts =
+      torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU);
+  torch::Tensor layout_cpu =
+      torch::empty({num_experts, layout_width}, cpu_opts);
+  int32_t* layout_data = layout_cpu.data_ptr<int32_t>();
   for (int i = 0; i < num_experts; ++i) {
     std::memcpy(layout_data + i * layout_width,  // dst (int32*)
                 &layout_B_reordered,             // src (LayoutType*)
                 sizeof(LayoutType));             // number of bytes
   }
-  torch::stable::Tensor packed_layout =
-      torch::stable::to(layout_cpu, b_tensors.device(),
-                        /*non_blocking=*/false);
+  torch::Tensor packed_layout =
+      layout_cpu.to(b_tensors.device(), /*non_blocking=*/false);

   return {b_tensors_packed, packed_layout};
 }

-STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) {
-  m.impl("cutlass_w4a8_moe_mm", TORCH_BOX(&mm));
-  m.impl("cutlass_encode_and_reorder_int4b_grouped",
-         TORCH_BOX(&encode_and_reorder_int4b));
+TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
+  m.impl("cutlass_w4a8_moe_mm", &mm);
+  m.impl("cutlass_encode_and_reorder_int4b_grouped", &encode_and_reorder_int4b);
 }

 }  // namespace vllm::cutlass_w4a8_moe
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -3,12 +3,14 @@
 // https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu
 //
-#include <torch/csrc/stable/library.h>
-#include <torch/csrc/stable/tensor.h>
-#include "libtorch_stable/torch_utils.h"
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <torch/all.h>
 #include "cutlass_extensions/torch_utils.hpp"
 #include "w4a8_utils.cuh"
+#include "core/registration.h"

 #include "cutlass/cutlass.h"
 #include <limits>
@@ -159,31 +161,31 @@
   using StrideD = typename GemmKernelShuffled::StrideD;
   using StrideS = typename CollectiveMainloopShuffled::StrideScale;

-  static torch::stable::Tensor mm(
-      torch::stable::Tensor const& A,
-      torch::stable::Tensor const& B,             // already packed
-      torch::stable::Tensor const& group_scales,  // already packed
-      int64_t group_size, torch::stable::Tensor const& channel_scales,
-      torch::stable::Tensor const& token_scales,
-      std::optional<torch::headeronly::ScalarType> const& maybe_out_type) {
+  static torch::Tensor mm(torch::Tensor const& A,
+                          torch::Tensor const& B,             // already packed
+                          torch::Tensor const& group_scales,  // already packed
+                          int64_t group_size,
+                          torch::Tensor const& channel_scales,
+                          torch::Tensor const& token_scales,
+                          std::optional<at::ScalarType> const& maybe_out_type) {
     // TODO: param validation
     int m = A.size(0);
    int k = A.size(1);
    int n = B.size(1);
     // safely cast group_size to int
-    STD_TORCH_CHECK(
-        group_size > 0 && group_size <= std::numeric_limits<int>::max(),
+    TORCH_CHECK(group_size > 0 && group_size <= std::numeric_limits<int>::max(),
                 "group_size out of supported range for int: ", group_size);
     int const group_size_int = static_cast<int>(group_size);

     // Allocate output
-    const torch::stable::accelerator::DeviceGuard device_guard(
-        A.get_device_index());
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
     auto device = A.device();
-    auto stream = get_current_cuda_stream(device.index());
-    torch::stable::Tensor D = torch::stable::empty(
-        {m, n}, equivalent_scalar_type_v<ElementD>, std::nullopt, device);
+    auto stream = at::cuda::getCurrentCUDAStream(device.index());
+    torch::Tensor D =
+        torch::empty({m, n}, torch::TensorOptions()
+                                 .dtype(equivalent_scalar_type_v<ElementD>)
+                                 .device(device));

     // prepare arg pointers
     auto A_ptr = static_cast<MmaType const*>(A.const_data_ptr());
     auto B_ptr = static_cast<QuantType const*>(B.const_data_ptr());
@@ -235,9 +237,9 @@
     // Workspace
     size_t workspace_size = GemmShuffled::get_workspace_size(arguments);
-    torch::stable::Tensor workspace = torch::stable::empty(
-        workspace_size, torch::headeronly::ScalarType::Byte, std::nullopt,
-        device);
+    torch::Tensor workspace =
+        torch::empty(workspace_size,
+                     torch::TensorOptions().dtype(torch::kU8).device(device));

     // Run GEMM
     GemmShuffled gemm;
@@ -267,13 +269,13 @@
 using Kernel_128x64_1x1x1 = W4A8GemmKernel<Shape<_128, _64>, Shape<_1, _1, _1>>;
 using Kernel_128x32_1x1x1 = W4A8GemmKernel<Shape<_128, _32>, Shape<_1, _1, _1>>;
 using Kernel_128x16_1x1x1 = W4A8GemmKernel<Shape<_128, _16>, Shape<_1, _1, _1>>;

-torch::stable::Tensor mm_dispatch(
-    torch::stable::Tensor const& A,
-    torch::stable::Tensor const& B,             // already packed
-    torch::stable::Tensor const& group_scales,  // already packed
-    int64_t group_size, torch::stable::Tensor const& channel_scales,
-    torch::stable::Tensor const& token_scales,
-    std::optional<torch::headeronly::ScalarType> const& maybe_out_type,
+torch::Tensor mm_dispatch(torch::Tensor const& A,
+                          torch::Tensor const& B,             // already packed
+                          torch::Tensor const& group_scales,  // already packed
+                          int64_t group_size,
+                          torch::Tensor const& channel_scales,
+                          torch::Tensor const& token_scales,
+                          std::optional<at::ScalarType> const& maybe_out_type,
                           const std::string& schedule) {
   if (schedule == "256x128_1x1x1") {
     return Kernel_256x128_1x1x1::mm(A, B, group_scales, group_size,
@@ -316,17 +318,16 @@
                                    channel_scales, token_scales,
                                    maybe_out_type);
   }
-  STD_TORCH_CHECK(false, "Unknown W4A8 schedule: ", schedule);
+  TORCH_CHECK(false, "Unknown W4A8 schedule: ", schedule);
   return {};
 }

-torch::stable::Tensor mm(
-    torch::stable::Tensor const& A,
-    torch::stable::Tensor const& B,             // already packed
-    torch::stable::Tensor const& group_scales,  // already packed
-    int64_t group_size, torch::stable::Tensor const& channel_scales,
-    torch::stable::Tensor const& token_scales,
-    std::optional<torch::headeronly::ScalarType> const& maybe_out_type,
+torch::Tensor mm(torch::Tensor const& A,
+                 torch::Tensor const& B,             // already packed
+                 torch::Tensor const& group_scales,  // already packed
+                 int64_t group_size, torch::Tensor const& channel_scales,
+                 torch::Tensor const& token_scales,
+                 std::optional<at::ScalarType> const& maybe_out_type,
                  std::optional<std::string> maybe_schedule) {
   // requested a specific schedule
   if (maybe_schedule) {
@@ -377,15 +378,14 @@
 // ----------------------------------------------------------------------------
 // Pre-processing utils
 // ----------------------------------------------------------------------------
-torch::stable::Tensor pack_scale_fp8(torch::stable::Tensor const& scales) {
-  STD_TORCH_CHECK(scales.scalar_type() ==
-                  torch::headeronly::ScalarType::Float8_e4m3fn);
-  STD_TORCH_CHECK(scales.is_contiguous());
-  STD_TORCH_CHECK(scales.is_cuda());
-  auto packed_scales =
-      torch::stable::empty({scales.numel() * ScalePackSize},
-                           scales.scalar_type(), std::nullopt, scales.device());
+torch::Tensor pack_scale_fp8(torch::Tensor const& scales) {
+  TORCH_CHECK(scales.dtype() == torch::kFloat8_e4m3fn);
+  TORCH_CHECK(scales.is_contiguous());
+  TORCH_CHECK(scales.is_cuda());
+  auto packed_scales = torch::empty(
+      {scales.numel() * ScalePackSize},
+      torch::TensorOptions().dtype(scales.dtype()).device(scales.device()));
   auto scales_ptr = static_cast<MmaType const*>(scales.const_data_ptr());
   auto packed_scales_ptr =
       static_cast<cutlass::Array<ElementScale, ScalePackSize>*>(
@@ -396,16 +396,15 @@
   return packed_scales;
 }

-torch::stable::Tensor encode_and_reorder_int4b(torch::stable::Tensor const& B) {
-  STD_TORCH_CHECK(B.scalar_type() == torch::headeronly::ScalarType::Int);
-  STD_TORCH_CHECK(B.dim() == 2);
-  torch::stable::Tensor B_packed = torch::stable::empty_like(B);
+torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) {
+  TORCH_CHECK(B.dtype() == torch::kInt32);
+  TORCH_CHECK(B.dim() == 2);
+  torch::Tensor B_packed = torch::empty_like(B);

   int k = B.size(0) * PackFactor;  // logical k
   int n = B.size(1);
-  STD_TORCH_CHECK((n * k) % 32 == 0,
-                  "need multiples of 32 int4s for 16B chunks");
+  TORCH_CHECK((n * k) % 32 == 0, "need multiples of 32 int4s for 16B chunks");

   auto B_ptr = static_cast<QuantType const*>(B.const_data_ptr());
   auto B_packed_ptr = static_cast<QuantType*>(B_packed.data_ptr());
@@ -416,17 +415,16 @@
   bool ok = vllm::cutlass_w4a8_utils::unified_encode_int4b(B_ptr, B_packed_ptr,
                                                            n * k);
-  STD_TORCH_CHECK(ok, "unified_encode_int4b failed");
+  TORCH_CHECK(ok, "unified_encode_int4b failed");

   cutlass::reorder_tensor(B_packed_ptr, layout_B, layout_B_reordered);
   return B_packed;
 }

-STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) {
-  m.impl("cutlass_w4a8_mm", TORCH_BOX(&mm));
-  m.impl("cutlass_pack_scale_fp8", TORCH_BOX(&pack_scale_fp8));
-  m.impl("cutlass_encode_and_reorder_int4b",
-         TORCH_BOX(&encode_and_reorder_int4b));
+TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
+  m.impl("cutlass_w4a8_mm", &mm);
+  m.impl("cutlass_pack_scale_fp8", &pack_scale_fp8);
+  m.impl("cutlass_encode_and_reorder_int4b", &encode_and_reorder_int4b);
 }

 }  // namespace vllm::cutlass_w4a8


@@ -14,15 +14,16 @@
  * limitations under the License.
  */
-#include <torch/csrc/stable/tensor.h>
-#include "libtorch_stable/torch_utils.h"
-#include "libtorch_stable/dispatch_utils.h"
-#include "cuda_vec_utils.cuh"
+#include <torch/all.h>
 #include <cuda_runtime_api.h>
 #include <cuda_runtime.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
 #include <cuda_fp8.h>
+#include "dispatch_utils.h"
 #include "cuda_utils.h"
 #include "launch_bounds_utils.h"
@@ -117,18 +118,16 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
 }  // namespace vllm
-void silu_and_mul_nvfp4_quant_sm1xxa(
-    torch::stable::Tensor& output,  // [..., d]
-    torch::stable::Tensor& output_sf,
-    torch::stable::Tensor& input,  // [..., 2 * d]
-    torch::stable::Tensor& input_sf) {
+void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output,  // [..., d]
+                                     torch::Tensor& output_sf,
+                                     torch::Tensor& input,  // [..., 2 * d]
+                                     torch::Tensor& input_sf) {
   int32_t m = input.size(0);
   int32_t n = input.size(1) / 2;
-  STD_TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");
-  STD_TORCH_CHECK(
-      input.scalar_type() == torch::headeronly::ScalarType::Half ||
-          input.scalar_type() == torch::headeronly::ScalarType::BFloat16,
+  TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");
+  TORCH_CHECK(input.scalar_type() == at::ScalarType::Half ||
+                  input.scalar_type() == at::ScalarType::BFloat16,
              "Unsupported input data type for quantize_to_fp4.");
   int multiProcessorCount =
@@ -137,9 +136,8 @@ void silu_and_mul_nvfp4_quant_sm1xxa(
   auto input_sf_ptr = static_cast<float const*>(input_sf.data_ptr());
   auto sf_out = static_cast<int32_t*>(output_sf.data_ptr());
   auto output_ptr = static_cast<int64_t*>(output.data_ptr());
-  const torch::stable::accelerator::DeviceGuard device_guard(
-      input.get_device_index());
-  auto stream = get_current_cuda_stream(input.get_device_index());
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+  auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
   dim3 block(std::min(int(n / ELTS_PER_THREAD), 512));
   int const numBlocksPerSM =
       vllm_runtime_blocks_per_sm(static_cast<int>(block.x));
@@ -151,7 +149,7 @@ void silu_and_mul_nvfp4_quant_sm1xxa(
       int(m), std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y));
   dim3 grid(grid_x, grid_y);
-  VLLM_STABLE_DISPATCH_HALF_TYPES(
+  VLLM_DISPATCH_HALF_TYPES(
       input.scalar_type(), "silu_and_mul_nvfp4_quant_kernel", [&] {
         using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
         auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());


@@ -14,12 +14,14 @@
  * limitations under the License.
  */
-#include <torch/csrc/stable/library.h>
-#include <torch/csrc/stable/tensor.h>
-#include "libtorch_stable/torch_utils.h"
+#include "core/registration.h"
+#include <torch/all.h>
 #include <cutlass/arch/arch.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <c10/cuda/CUDAStream.h>
 #include "cutlass_extensions/common.hpp"
 #include "cute/tensor.hpp"
@@ -120,7 +122,7 @@ __global__ void __get_group_gemm_starts(
 #define __CALL_GET_STARTS_KERNEL_BLOCKSCALE(ELEMENT_AB_TYPE, SF_TYPE,         \
                                             TENSOR_C_TYPE, C_TYPE, LayoutSFA, \
                                             LayoutSFB, ScaleConfig)           \
-  else if (out_tensors.scalar_type() == TENSOR_C_TYPE) {                      \
+  else if (out_tensors.dtype() == TENSOR_C_TYPE) {                            \
    __get_group_gemm_starts<ELEMENT_AB_TYPE, C_TYPE, SF_TYPE, float,          \
                             LayoutSFA, LayoutSFB, ScaleConfig>                \
        <<<1, num_experts, 0, stream>>>(                                       \
@@ -148,36 +150,26 @@ __global__ void __get_group_gemm_starts(
 }
 template <typename LayoutSFA, typename LayoutSFB, typename ScaleConfig>
-void run_get_group_gemm_starts(const torch::stable::Tensor& a_starts,
-                               const torch::stable::Tensor& b_starts,
-                               const torch::stable::Tensor& out_starts,
-                               const torch::stable::Tensor& a_scales_starts,
-                               const torch::stable::Tensor& b_scales_starts,
-                               const torch::stable::Tensor& alpha_starts,
-                               const torch::stable::Tensor& layout_sfa,
-                               const torch::stable::Tensor& layout_sfb,
-                               const torch::stable::Tensor& a_strides,
-                               const torch::stable::Tensor& b_strides,
-                               const torch::stable::Tensor& c_strides,
-                               int64_t a_stride_val, int64_t b_stride_val,
+void run_get_group_gemm_starts(
+    const torch::Tensor& a_starts, const torch::Tensor& b_starts,
+    const torch::Tensor& out_starts, const torch::Tensor& a_scales_starts,
+    const torch::Tensor& b_scales_starts, const torch::Tensor& alpha_starts,
+    const torch::Tensor& layout_sfa, const torch::Tensor& layout_sfb,
+    const torch::Tensor& a_strides, const torch::Tensor& b_strides,
+    const torch::Tensor& c_strides, int64_t a_stride_val, int64_t b_stride_val,
    int64_t c_stride_val,
    /*these are used for their base addresses*/
-                               torch::stable::Tensor const& a_tensors,
-                               torch::stable::Tensor const& b_tensors,
-                               torch::stable::Tensor const& out_tensors,
-                               torch::stable::Tensor const& a_scales,
-                               torch::stable::Tensor const& b_scales,
-                               torch::stable::Tensor const& alphas,
-                               torch::stable::Tensor const& expert_offsets,
-                               torch::stable::Tensor const& sf_offsets,
-                               torch::stable::Tensor const& problem_sizes,
-                               int M, int N, int K) {
+    torch::Tensor const& a_tensors, torch::Tensor const& b_tensors,
+    torch::Tensor const& out_tensors, torch::Tensor const& a_scales,
+    torch::Tensor const& b_scales, torch::Tensor const& alphas,
+    torch::Tensor const& expert_offsets, torch::Tensor const& sf_offsets,
+    torch::Tensor const& problem_sizes, int M, int N, int K) {
   int num_experts = (int)expert_offsets.size(0);
-  auto stream = get_current_cuda_stream(a_tensors.get_device_index());
-  STD_TORCH_CHECK(out_tensors.size(1) == N,
+  auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
+  TORCH_CHECK(out_tensors.size(1) == N,
              "Output tensor shape doesn't match expected shape");
-  STD_TORCH_CHECK(K / 2 == b_tensors.size(2),
+  TORCH_CHECK(K / 2 == b_tensors.size(2),
              "b_tensors(dim = 2) and a_tensors(dim = 1) trailing"
              " dimension must match");
   if (false) {
@@ -185,27 +177,23 @@ void run_get_group_gemm_starts(const torch::stable::Tensor& a_starts,
   //(ELEMENT_AB_TYPE, BS_TYPE, TENSOR_C_TYPE, C_TYPE, LayoutSFA, LayoutSFB,
   // ScaleConfig)
   __CALL_GET_STARTS_KERNEL_BLOCKSCALE(
-      cutlass::float_e2m1_t, cutlass::float_ue4m3_t,
-      torch::headeronly::ScalarType::BFloat16, cutlass::bfloat16_t, LayoutSFA,
-      LayoutSFB, ScaleConfig)
+      cutlass::float_e2m1_t, cutlass::float_ue4m3_t, torch::kBFloat16,
+      cutlass::bfloat16_t, LayoutSFA, LayoutSFB, ScaleConfig)
   __CALL_GET_STARTS_KERNEL_BLOCKSCALE(cutlass::float_e2m1_t,
-                                      cutlass::float_ue4m3_t,
-                                      torch::headeronly::ScalarType::Half, half,
-                                      LayoutSFA, LayoutSFB, ScaleConfig)
+                                      cutlass::float_ue4m3_t, torch::kFloat16,
+                                      half, LayoutSFA, LayoutSFB, ScaleConfig)
   else {
-    STD_TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
+    TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
   }
 }
 template <typename OutType>
 void run_fp4_blockwise_scaled_group_mm_sm100(
-    torch::stable::Tensor& output, const torch::stable::Tensor& a,
-    const torch::stable::Tensor& b, const torch::stable::Tensor& a_blockscale,
-    const torch::stable::Tensor& b_blockscales,
-    const torch::stable::Tensor& alphas,
-    const torch::stable::Tensor& problem_sizes,
-    const torch::stable::Tensor& expert_offsets,
-    const torch::stable::Tensor& sf_offsets, int M, int N, int K) {
+    torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
+    const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
+    const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
+    const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M,
+    int N, int K) {
   using ProblemShape =
       cutlass::gemm::GroupProblemShape<Shape<int32_t, int32_t, int32_t>>;
   using ElementType = cutlass::float_e2m1_t;
@@ -284,40 +272,20 @@ void run_fp4_blockwise_scaled_group_mm_sm100(
   using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
   int num_experts = static_cast<int>(expert_offsets.size(0));
-  torch::stable::Tensor a_ptrs =
-      torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
-                           std::nullopt, a.device());
-  torch::stable::Tensor b_ptrs =
-      torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
-                           std::nullopt, a.device());
-  torch::stable::Tensor out_ptrs =
-      torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
-                           std::nullopt, a.device());
-  torch::stable::Tensor a_scales_ptrs =
-      torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
-                           std::nullopt, a.device());
-  torch::stable::Tensor b_scales_ptrs =
-      torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
-                           std::nullopt, a.device());
-  torch::stable::Tensor alpha_ptrs =
-      torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
-                           std::nullopt, a.device());
-  torch::stable::Tensor layout_sfa = torch::stable::empty(
-      {num_experts, 5}, torch::headeronly::ScalarType::Long, std::nullopt,
-      a.device());
-  torch::stable::Tensor layout_sfb = torch::stable::empty(
-      {num_experts, 5}, torch::headeronly::ScalarType::Long, std::nullopt,
-      a.device());
-  torch::stable::Tensor a_strides1 =
-      torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
-                           std::nullopt, a.device());
-  torch::stable::Tensor b_strides1 =
-      torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
-                           std::nullopt, a.device());
-  torch::stable::Tensor c_strides1 =
-      torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
-                           std::nullopt, a.device());
+  auto options_int =
+      torch::TensorOptions().dtype(torch::kInt64).device(a.device());
+  torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
+  torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
+  torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
+  torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
+  torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
+  torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int);
+  torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int);
+  torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int);
+  torch::Tensor a_strides1 = torch::empty(num_experts, options_int);
+  torch::Tensor b_strides1 = torch::empty(num_experts, options_int);
+  torch::Tensor c_strides1 = torch::empty(num_experts, options_int);
   run_get_group_gemm_starts<LayoutSFA, LayoutSFB, ScaleConfig>(
       a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, alpha_ptrs,
@@ -340,7 +308,7 @@ void run_fp4_blockwise_scaled_group_mm_sm100(
           typename ProblemShape::UnderlyingProblemShape>::RasterOrderOptions;
   typename Gemm::GemmKernel::TileSchedulerArguments scheduler;
   scheduler.raster_order = RasterOrderOptions::AlongM;
-  hw_info.device_id = a.get_device_index();
+  hw_info.device_id = a.get_device();
   static std::unordered_map<int, int> cached_sm_counts;
   if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) {
     cached_sm_counts[hw_info.device_id] =
@@ -382,35 +350,32 @@ void run_fp4_blockwise_scaled_group_mm_sm100(
       scheduler};
   size_t workspace_size = Gemm::get_workspace_size(args);
-  auto workspace =
-      torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
-                           std::nullopt, a.device());
-  const cudaStream_t stream = get_current_cuda_stream(a.get_device_index());
+  auto const workspace_options =
+      torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
+  auto workspace = torch::empty(workspace_size, workspace_options);
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device());
   auto can_implement_status = gemm_op.can_implement(args);
-  STD_TORCH_CHECK(
-      can_implement_status == cutlass::Status::kSuccess,
+  TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess,
              "Failed to implement GEMM: status=", (int)can_implement_status);
   // Run the GEMM
   auto status = gemm_op.initialize(args, workspace.data_ptr());
-  STD_TORCH_CHECK(status == cutlass::Status::kSuccess,
+  TORCH_CHECK(status == cutlass::Status::kSuccess,
              "Failed to initialize GEMM: status=", (int)status,
-             " workspace_size=", workspace_size,
-             " num_experts=", num_experts, " M=", M, " N=", N, " K=", K);
+             " workspace_size=", workspace_size, " num_experts=", num_experts,
+             " M=", M, " N=", N, " K=", K);
   status = gemm_op.run(args, workspace.data_ptr(), stream);
-  STD_TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
+  TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
 }
 void run_fp4_blockwise_scaled_group_mm_sm120(
-    torch::stable::Tensor& output, const torch::stable::Tensor& a,
-    const torch::stable::Tensor& b, const torch::stable::Tensor& a_blockscale,
-    const torch::stable::Tensor& b_blockscales,
-    const torch::stable::Tensor& alphas,
-    const torch::stable::Tensor& problem_sizes,
-    const torch::stable::Tensor& expert_offsets,
-    const torch::stable::Tensor& sf_offsets, int M, int N, int K) {
+    torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
+    const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
+    const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
+    const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M,
+    int N, int K) {
   using ProblemShape =
      cutlass::gemm::GroupProblemShape<Shape<int32_t, int32_t, int32_t>>;
   using ElementType = cutlass::float_e2m1_t;
@@ -481,40 +446,20 @@ void run_fp4_blockwise_scaled_group_mm_sm120(
   using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
   int num_experts = static_cast<int>(expert_offsets.size(0));
-  torch::stable::Tensor a_ptrs =
-      torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
-                           std::nullopt, a.device());
-  torch::stable::Tensor b_ptrs =
-      torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
-                           std::nullopt, a.device());
-  torch::stable::Tensor out_ptrs =
-      torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
-                           std::nullopt, a.device());
-  torch::stable::Tensor a_scales_ptrs =
-      torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
-                           std::nullopt, a.device());
-  torch::stable::Tensor b_scales_ptrs =
-      torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
-                           std::nullopt, a.device());
-  torch::stable::Tensor alpha_ptrs =
-      torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
-                           std::nullopt, a.device());
-  torch::stable::Tensor layout_sfa = torch::stable::empty(
-      {num_experts, 5}, torch::headeronly::ScalarType::Long, std::nullopt,
-      a.device());
-  torch::stable::Tensor layout_sfb = torch::stable::empty(
-      {num_experts, 5}, torch::headeronly::ScalarType::Long, std::nullopt,
-      a.device());
-  torch::stable::Tensor a_strides1 =
-      torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
-                           std::nullopt, a.device());
-  torch::stable::Tensor b_strides1 =
-      torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
-                           std::nullopt, a.device());
-  torch::stable::Tensor c_strides1 =
-      torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
-                           std::nullopt, a.device());
+  auto options_int =
+      torch::TensorOptions().dtype(torch::kInt64).device(a.device());
+  torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
+  torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
+  torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
+  torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
+  torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
+  torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int);
+  torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int);
+  torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int);
+  torch::Tensor a_strides1 = torch::empty(num_experts, options_int);
+  torch::Tensor b_strides1 = torch::empty(num_experts, options_int);
+  torch::Tensor c_strides1 = torch::empty(num_experts, options_int);
   run_get_group_gemm_starts<LayoutSFA, LayoutSFB, ScaleConfig>(
      a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, alpha_ptrs,
@@ -535,7 +480,7 @@ void run_fp4_blockwise_scaled_group_mm_sm120(
   using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions;
   typename Gemm::GemmKernel::TileSchedulerArguments scheduler;
   scheduler.raster_order = RasterOrderOptions::AlongM;
-  hw_info.device_id = a.get_device_index();
+  hw_info.device_id = a.get_device();
   static std::unordered_map<int, int> cached_sm_counts;
   if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) {
     cached_sm_counts[hw_info.device_id] =
@@ -578,36 +523,33 @@ void run_fp4_blockwise_scaled_group_mm_sm120(
      scheduler};
   size_t workspace_size = Gemm::get_workspace_size(args);
-  auto workspace =
-      torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
-                           std::nullopt, a.device());
-  const cudaStream_t stream = get_current_cuda_stream(a.get_device_index());
+  auto const workspace_options =
+      torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
+  auto workspace = torch::empty(workspace_size, workspace_options);
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device());
   auto can_implement_status = gemm_op.can_implement(args);
-  STD_TORCH_CHECK(
-      can_implement_status == cutlass::Status::kSuccess,
+  TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess,
             "Failed to implement GEMM: status=", (int)can_implement_status);
   // Run the GEMM
   auto status = gemm_op.initialize(args, workspace.data_ptr());
-  STD_TORCH_CHECK(status == cutlass::Status::kSuccess,
+  TORCH_CHECK(status == cutlass::Status::kSuccess,
             "Failed to initialize GEMM: status=", (int)status,
-            " workspace_size=", workspace_size,
-            " num_experts=", num_experts, " M=", M, " N=", N, " K=", K);
+            " workspace_size=", workspace_size, " num_experts=", num_experts,
+            " M=", M, " N=", N, " K=", K);
   status = gemm_op.run(args, workspace.data_ptr(), stream);
-  STD_TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
+  TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
 }
 template <typename OutType>
 void run_fp4_blockwise_scaled_group_mm(
-    torch::stable::Tensor& output, const torch::stable::Tensor& a,
-    const torch::stable::Tensor& b, const torch::stable::Tensor& a_blockscale,
-    const torch::stable::Tensor& b_blockscales,
-    const torch::stable::Tensor& alphas,
-    const torch::stable::Tensor& problem_sizes,
-    const torch::stable::Tensor& expert_offsets,
-    const torch::stable::Tensor& sf_offsets, int M, int N, int K) {
+    torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
+    const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
+    const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
+    const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M,
+    int N, int K) {
   int32_t version_num = get_sm_version_num();
 #if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
   if (version_num >= 120 && version_num < 130) {
@@ -625,7 +567,7 @@ void run_fp4_blockwise_scaled_group_mm(
     return;
   }
 #endif
-  STD_TORCH_CHECK_NOT_IMPLEMENTED(
+  TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "No compiled cutlass_fp4_group_mm kernel for CUDA device capability: ",
      version_num, ". Required capability: 100 or 120");
@@ -633,31 +575,26 @@
 #if (defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100) || \
     (defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120)
-constexpr auto FLOAT4_E2M1X2 = torch::headeronly::ScalarType::Byte;
-constexpr auto SF_DTYPE = torch::headeronly::ScalarType::Float8_e4m3fn;
+constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
+constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
 #endif
 #define CHECK_TYPE(x, st, m) \
-  STD_TORCH_CHECK(x.scalar_type() == st, \
-                  ": Inconsistency of torch::stable::Tensor type:", m)
+  TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m)
 #define CHECK_TH_CUDA(x, m) \
-  STD_TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor.")
+  TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor.")
 #define CHECK_CONTIGUOUS(x, m) \
-  STD_TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous.")
+  TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous.")
 #define CHECK_INPUT(x, st, m) \
   CHECK_TH_CUDA(x, m);        \
   CHECK_CONTIGUOUS(x, m);     \
   CHECK_TYPE(x, st, m)
-void cutlass_fp4_group_mm(torch::stable::Tensor& output,
-                          const torch::stable::Tensor& a,
-                          const torch::stable::Tensor& b,
-                          const torch::stable::Tensor& a_blockscale,
-                          const torch::stable::Tensor& b_blockscales,
-                          const torch::stable::Tensor& alphas,
-                          const torch::stable::Tensor& problem_sizes,
-                          const torch::stable::Tensor& expert_offsets,
-                          const torch::stable::Tensor& sf_offsets) {
+void cutlass_fp4_group_mm(
+    torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
+    const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
+    const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
+    const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets) {
 #if (defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100) || \
     (defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120)
   // Input validation
@@ -665,26 +602,22 @@ void cutlass_fp4_group_mm(torch::stable::Tensor& output,
   CHECK_INPUT(b, FLOAT4_E2M1X2, "b");
   CHECK_INPUT(a_blockscale, SF_DTYPE, "a_blockscale");
   CHECK_INPUT(b_blockscales, SF_DTYPE, "b_blockscales");
-  CHECK_INPUT(alphas, torch::headeronly::ScalarType::Float, "alphas");
-  STD_TORCH_CHECK(
-      a_blockscale.dim() == 2,
+  CHECK_INPUT(alphas, at::ScalarType::Float, "alphas");
+  TORCH_CHECK(a_blockscale.dim() == 2,
             "expected a_blockscale to be of shape [num_experts, rounded_m,"
             " k // group_size], observed rank: ",
             a_blockscale.dim())
-  STD_TORCH_CHECK(b_blockscales.dim() == 3,
+  TORCH_CHECK(b_blockscales.dim() == 3,
            "expected b_blockscale to be of shape: "
           " [num_experts, n, k // group_size], observed rank: ",
           b_blockscales.dim())
-  STD_TORCH_CHECK(problem_sizes.dim() == 2,
-                  "problem_sizes must be a 2D tensor");
-  STD_TORCH_CHECK(problem_sizes.size(1) == 3,
+  TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be a 2D tensor");
+  TORCH_CHECK(problem_sizes.size(1) == 3,
             "problem_sizes must have the shape (num_experts, 3)");
-  STD_TORCH_CHECK(
-      problem_sizes.size(0) == expert_offsets.size(0),
+  TORCH_CHECK(problem_sizes.size(0) == expert_offsets.size(0),
            "Number of experts in problem_sizes must match expert_offsets");
-  STD_TORCH_CHECK(
-      problem_sizes.scalar_type() == torch::headeronly::ScalarType::Int,
+  TORCH_CHECK(problem_sizes.dtype() == torch::kInt32,
           "problem_sizes must be int32.");
   int M = static_cast<int>(a.size(0));
@@ -692,7 +625,7 @@ void cutlass_fp4_group_mm(torch::stable::Tensor& output,
   int E = static_cast<int>(b.size(0));
   int K = static_cast<int>(2 * b.size(2));
-  if (output.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
+  if (output.scalar_type() == torch::kBFloat16) {
    run_fp4_blockwise_scaled_group_mm<cutlass::bfloat16_t>(
        output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes,
        expert_offsets, sf_offsets, M, N, K);
@@ -700,7 +633,7 @@ void cutlass_fp4_group_mm(torch::stable::Tensor& output,
 #if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
    int32_t version_num = get_sm_version_num();
    if (version_num >= 120 && version_num < 130) {
-      STD_TORCH_CHECK_NOT_IMPLEMENTED(
+      TORCH_CHECK_NOT_IMPLEMENTED(
          false, "SM120 NVFP4 MOE only supports bfloat16 output, got: ",
          output.scalar_type());
    }
@@ -710,7 +643,7 @@ void cutlass_fp4_group_mm(torch::stable::Tensor& output,
        expert_offsets, sf_offsets, M, N, K);
   }
 #else
-  STD_TORCH_CHECK_NOT_IMPLEMENTED(
+  TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "No compiled cutlass_fp4_group_mm kernel, vLLM must "
     "be compiled with ENABLE_NVFP4_SM100 or ENABLE_NVFP4_SM120 for SM100/120 "
@@ -718,6 +651,6 @@ void cutlass_fp4_group_mm(torch::stable::Tensor& output,
 #endif
 }
-STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) {
-  m.impl("cutlass_fp4_group_mm", TORCH_BOX(&cutlass_fp4_group_mm));
+TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
+  m.impl("cutlass_fp4_group_mm", &cutlass_fp4_group_mm);
 }


@@ -14,15 +14,16 @@
  * limitations under the License.
  */
-#include <torch/csrc/stable/tensor.h>
-#include "libtorch_stable/torch_utils.h"
-#include "libtorch_stable/dispatch_utils.h"
-#include "cuda_vec_utils.cuh"
+#include <torch/all.h>
 #include <cuda_runtime_api.h>
 #include <cuda_runtime.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
 #include <cuda_fp8.h>
+#include "dispatch_utils.h"
 #include "cuda_utils.h"
 #include "nvfp4_utils.cuh"
@@ -326,28 +327,25 @@ void quant_impl(void* output, void* output_scale, void* input,
 }  // namespace vllm
 /*Quantization entry for fp4 experts quantization*/
-#define CHECK_TH_CUDA(x, m) \
-  STD_TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
+#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
 #define CHECK_CONTIGUOUS(x, m) \
-  STD_TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
+  TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
 #define CHECK_INPUT(x, m) \
   CHECK_TH_CUDA(x, m);    \
   CHECK_CONTIGUOUS(x, m);
-constexpr auto HALF = torch::headeronly::ScalarType::Half;
-constexpr auto BF16 = torch::headeronly::ScalarType::BFloat16;
-constexpr auto FLOAT = torch::headeronly::ScalarType::Float;
-constexpr auto INT = torch::headeronly::ScalarType::Int;
-constexpr auto UINT8 = torch::headeronly::ScalarType::Byte;
+constexpr auto HALF = at::ScalarType::Half;
+constexpr auto BF16 = at::ScalarType::BFloat16;
+constexpr auto FLOAT = at::ScalarType::Float;
+constexpr auto INT = at::ScalarType::Int;
+constexpr auto UINT8 = at::ScalarType::Byte;
 // Common validation for fp4 experts quantization entry points.
 static void validate_fp4_experts_quant_inputs(
-    torch::stable::Tensor const& output,
-    torch::stable::Tensor const& output_scale,
-    torch::stable::Tensor const& input,
-    torch::stable::Tensor const& input_global_scale,
-    torch::stable::Tensor const& input_offset_by_experts,
-    torch::stable::Tensor const& output_scale_offset_by_experts, int64_t m_topk,
+    torch::Tensor const& output, torch::Tensor const& output_scale,
+    torch::Tensor const& input, torch::Tensor const& input_global_scale,
+    torch::Tensor const& input_offset_by_experts,
+    torch::Tensor const& output_scale_offset_by_experts, int64_t m_topk,
    int64_t k) {
   CHECK_INPUT(output, "output");
   CHECK_INPUT(output_scale, "output_scale");
@@ -356,42 +354,41 @@ static void validate_fp4_experts_quant_inputs(
   CHECK_INPUT(input_offset_by_experts, "input_offset_by_experts");
   CHECK_INPUT(output_scale_offset_by_experts, "output_scale_offset_by_experts");
-  STD_TORCH_CHECK(output.dim() == 2);
-  STD_TORCH_CHECK(output_scale.dim() == 2);
-  STD_TORCH_CHECK(input.dim() == 2);
-  STD_TORCH_CHECK(input_global_scale.dim() == 1);
-  STD_TORCH_CHECK(input_offset_by_experts.dim() == 1);
-  STD_TORCH_CHECK(output_scale_offset_by_experts.dim() == 1);
-  STD_TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16);
-  STD_TORCH_CHECK(input_global_scale.scalar_type() == FLOAT);
-  STD_TORCH_CHECK(input_offset_by_experts.scalar_type() == INT);
-  STD_TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT);
+  TORCH_CHECK(output.dim() == 2);
+  TORCH_CHECK(output_scale.dim() == 2);
+  TORCH_CHECK(input.dim() == 2);
+  TORCH_CHECK(input_global_scale.dim() == 1);
+  TORCH_CHECK(input_offset_by_experts.dim() == 1);
+  TORCH_CHECK(output_scale_offset_by_experts.dim() == 1);
+  TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16);
+  TORCH_CHECK(input_global_scale.scalar_type() == FLOAT);
+  TORCH_CHECK(input_offset_by_experts.scalar_type() == INT);
+  TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT);
   // output is uint8 (two nvfp4 values are packed into one uint8)
   // output_scale is int32 (four fp8 values are packed into one int32)
-  STD_TORCH_CHECK(output.scalar_type() == UINT8);
-  STD_TORCH_CHECK(output_scale.scalar_type() == INT);
+  TORCH_CHECK(output.scalar_type() == UINT8);
+  TORCH_CHECK(output_scale.scalar_type() == INT);
   const int BLOCK_SIZE = 16;
-  STD_TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16");
+  TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16");
   auto n_experts = input_global_scale.size(0);
-  STD_TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1);
-  STD_TORCH_CHECK(output_scale_offset_by_experts.size(0) == n_experts + 1);
-  STD_TORCH_CHECK(output.size(0) == m_topk);
-  STD_TORCH_CHECK(output.size(1) == k / 2);
+  TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1);
+  TORCH_CHECK(output_scale_offset_by_experts.size(0) == n_experts + 1);
+  TORCH_CHECK(output.size(0) == m_topk);
+  TORCH_CHECK(output.size(1) == k / 2);
   int scales_k = k / BLOCK_SIZE;
   // 4 means the swizzle requirement by nvidia nvfp4.
   int padded_k = (scales_k + (4 - 1)) / 4 * 4;
   // 4 means 4 fp8 values are packed into one int32
-  STD_TORCH_CHECK(output_scale.size(1) * 4 == padded_k);
+  TORCH_CHECK(output_scale.size(1) * 4 == padded_k);
 }
 void scaled_fp4_experts_quant_sm1xxa(
-    torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
-    torch::stable::Tensor const& input,
-    torch::stable::Tensor const& input_global_scale,
-    torch::stable::Tensor const& input_offset_by_experts,
-    torch::stable::Tensor const& output_scale_offset_by_experts) {
+    torch::Tensor& output, torch::Tensor& output_scale,
+    torch::Tensor const& input, torch::Tensor const& input_global_scale,
+    torch::Tensor const& input_offset_by_experts,
+    torch::Tensor const& output_scale_offset_by_experts) {
   auto m_topk = input.size(0);
   auto k = input.size(1);
@@ -400,11 +397,11 @@ void scaled_fp4_experts_quant_sm1xxa(
       output_scale_offset_by_experts, m_topk, k);
   auto n_experts = input_global_scale.size(0);
-  const torch::stable::accelerator::DeviceGuard device_guard(
-      input.get_device_index());
-  const cudaStream_t stream = get_current_cuda_stream(input.get_device_index());
-  VLLM_STABLE_DISPATCH_HALF_TYPES(
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+  const cudaStream_t stream =
+      at::cuda::getCurrentCUDAStream(input.get_device());
+  VLLM_DISPATCH_HALF_TYPES(
       input.scalar_type(), "nvfp4_experts_quant_kernel", [&] {
         using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
         vllm::quant_impl<cuda_type, /*FUSE_SILU_MUL=*/false>(
@@ -416,15 +413,14 @@ void scaled_fp4_experts_quant_sm1xxa(
 }
 void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
-    torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
-    torch::stable::Tensor const& input,
-    torch::stable::Tensor const& input_global_scale,
-    torch::stable::Tensor const& input_offset_by_experts,
-    torch::stable::Tensor const& output_scale_offset_by_experts) {
+    torch::Tensor& output, torch::Tensor& output_scale,
+    torch::Tensor const& input, torch::Tensor const& input_global_scale,
+    torch::Tensor const& input_offset_by_experts,
+    torch::Tensor const& output_scale_offset_by_experts) {
   auto m_topk = input.size(0);
   // Input has gate || up layout, so k = input.size(1) / 2
   auto k_times_2 = input.size(1);
-  STD_TORCH_CHECK(k_times_2 % 2 == 0, "input width must be even (gate || up)");
+  TORCH_CHECK(k_times_2 % 2 == 0, "input width must be even (gate || up)");
   auto k = k_times_2 / 2;
   validate_fp4_experts_quant_inputs(output, output_scale, input,
@@ -432,11 +428,11 @@ void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
       output_scale_offset_by_experts, m_topk, k);
   auto n_experts = input_global_scale.size(0);
-  const torch::stable::accelerator::DeviceGuard device_guard(
-      input.get_device_index());
-  const cudaStream_t stream = get_current_cuda_stream(input.get_device_index());
-  VLLM_STABLE_DISPATCH_HALF_TYPES(
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+  const cudaStream_t stream =
+      at::cuda::getCurrentCUDAStream(input.get_device());
+  VLLM_DISPATCH_HALF_TYPES(
      input.scalar_type(), "silu_mul_nvfp4_experts_quant_kernel", [&] {
        using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
        vllm::quant_impl<cuda_type, /*FUSE_SILU_MUL=*/true>(


@@ -0,0 +1,134 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/all.h>
#include "nvfp4_utils.cuh"
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
torch::Tensor const& input,
torch::Tensor const& output_sf,
torch::Tensor const& input_sf,
bool is_sf_swizzled_layout);
#endif
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
void scaled_fp4_experts_quant_sm1xxa(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts);
#endif
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output,
torch::Tensor& output_sf,
torch::Tensor& input,
torch::Tensor& input_sf);
#endif
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts);
#endif
void scaled_fp4_quant_out(torch::Tensor const& input,
torch::Tensor const& input_sf,
bool is_sf_swizzled_layout, torch::Tensor& output,
torch::Tensor& output_sf) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
return scaled_fp4_quant_sm1xxa(output, input, output_sf, input_sf,
is_sf_swizzled_layout);
#endif
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization kernel");
}
std::tuple<torch::Tensor, torch::Tensor> scaled_fp4_quant_func(
torch::Tensor const& input, torch::Tensor const& input_sf,
bool is_sf_swizzled_layout) {
int64_t n = input.size(-1);
int64_t m = input.numel() / n;
auto device = input.device();
// Two fp4 values packed into a uint8
auto output = torch::empty(
{m, n / 2}, torch::TensorOptions().device(device).dtype(torch::kUInt8));
torch::Tensor output_sf;
if (is_sf_swizzled_layout) {
auto [sf_m, sf_n] = vllm::computeSwizzledSFShape(m, n);
output_sf = torch::empty(
{sf_m, sf_n},
torch::TensorOptions().device(device).dtype(torch::kInt32));
} else {
output_sf = torch::empty(
{m, n / CVT_FP4_SF_VEC_SIZE},
torch::TensorOptions().device(device).dtype(torch::kUInt8));
}
scaled_fp4_quant_out(input, input_sf, is_sf_swizzled_layout, output,
output_sf);
return {output, output_sf};
}
void scaled_fp4_experts_quant(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
return scaled_fp4_experts_quant_sm1xxa(
output, output_scale, input, input_global_scale, input_offset_by_experts,
output_scale_offset_by_experts);
#endif
TORCH_CHECK_NOT_IMPLEMENTED(false,
"No compiled nvfp4 experts quantization kernel");
}
void silu_and_mul_nvfp4_quant(torch::Tensor& output, torch::Tensor& output_sf,
torch::Tensor& input, torch::Tensor& input_sf) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
return silu_and_mul_nvfp4_quant_sm1xxa(output, output_sf, input, input_sf);
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
false, "No compiled silu_and_mul nvfp4 quantization kernel");
}
void silu_and_mul_scaled_fp4_experts_quant(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
return silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
output, output_scale, input, input_global_scale, input_offset_by_experts,
output_scale_offset_by_experts);
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
false, "No compiled silu_and_mul nvfp4 experts quantization kernel");
}
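
The entry points in this new file allocate their own outputs: two FP4 values are packed into each uint8 of `output`, and in the linear (non-swizzled) layout one FP8 scale covers each 16-element block of the last dimension. A minimal standalone sketch of that shape arithmetic, assuming CVT_FP4_SF_VEC_SIZE is 16 (the same block size the validation code in this diff uses); the concrete m and n below are purely illustrative:

#include <cstdint>
#include <iostream>

int main() {
  // Hypothetical example sizes, not taken from the diff.
  const int64_t m = 8, n = 7168;
  const int64_t sf_vec = 16;           // assumed CVT_FP4_SF_VEC_SIZE
  const int64_t packed_cols = n / 2;   // two fp4 values per uint8
  const int64_t sf_cols = n / sf_vec;  // one fp8 scale per 16 elements
  std::cout << "output:    [" << m << ", " << packed_cols << "] uint8\n"
            << "output_sf: [" << m << ", " << sf_cols
            << "] uint8 (linear layout)\n";
  return 0;
}

For the swizzled layout the scale-factor tensor is instead shaped by computeSwizzledSFShape and stored as int32, with four FP8 scales packed per element, as the file above shows.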


@@ -14,16 +14,16 @@
  * limitations under the License.
  */
-#include <torch/csrc/stable/tensor.h>
+#include <torch/all.h>
 #include <cuda_runtime_api.h>
 #include <cuda_runtime.h>
-#include <cuda_fp8.h>
-#include "libtorch_stable/torch_utils.h"
-#include "libtorch_stable/dispatch_utils.h"
-#include "cuda_vec_utils.cuh"
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <cuda_fp8.h>
+#include "dispatch_utils.h"
 #include "cuda_utils.h"
 #include "launch_bounds_utils.h"
@@ -173,18 +173,17 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
 }  // namespace vllm
-void scaled_fp4_quant_sm1xxa(torch::stable::Tensor const& output,
-                             torch::stable::Tensor const& input,
-                             torch::stable::Tensor const& output_sf,
-                             torch::stable::Tensor const& input_sf,
+void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
+                             torch::Tensor const& input,
+                             torch::Tensor const& output_sf,
+                             torch::Tensor const& input_sf,
                              bool is_sf_swizzled_layout) {
   int32_t m = input.size(0);
   int32_t n = input.size(1);
-  STD_TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");
-  STD_TORCH_CHECK(
-      input.scalar_type() == torch::headeronly::ScalarType::Half ||
-          input.scalar_type() == torch::headeronly::ScalarType::BFloat16,
+  TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");
+  TORCH_CHECK(input.scalar_type() == at::ScalarType::Half ||
+                  input.scalar_type() == at::ScalarType::BFloat16,
              "Unsupported input data type for quantize_to_fp4.");
   int multiProcessorCount =
@@ -193,9 +192,8 @@ void scaled_fp4_quant_sm1xxa(torch::stable::Tensor const& output,
   auto input_sf_ptr = static_cast<float const*>(input_sf.data_ptr());
   auto sf_out = static_cast<int32_t*>(output_sf.data_ptr());
   auto output_ptr = static_cast<int64_t*>(output.data_ptr());
-  const torch::stable::accelerator::DeviceGuard device_guard(
-      input.get_device_index());
-  auto stream = get_current_cuda_stream(input.get_device_index());
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+  auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
   int sf_n_unpadded = int(n / CVT_FP4_SF_VEC_SIZE);
@@ -215,10 +213,10 @@ void scaled_fp4_quant_sm1xxa(torch::stable::Tensor const& output,
        std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y));
    dim3 grid(grid_x, grid_y);
-    VLLM_STABLE_DISPATCH_HALF_TYPES(
-        input.scalar_type(), "nvfp4_quant_kernel", [&] {
+    VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] {
      using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
      auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
+      // NOTE: We don't support e8m0 scales at this moment.
      vllm::cvt_fp16_to_fp4<cuda_type, false><<<grid, block, 0, stream>>>(
          m, n, num_padded_cols, input_ptr, input_sf_ptr,
          reinterpret_cast<uint32_t*>(output_ptr),
@@ -231,13 +229,13 @@ void scaled_fp4_quant_sm1xxa(torch::stable::Tensor const& output,
        m, std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y));
    dim3 grid(grid_x, grid_y);
-    VLLM_STABLE_DISPATCH_HALF_TYPES(
-        input.scalar_type(), "nvfp4_quant_kernel", [&] {
+    VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] {
      using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
      auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
+      // NOTE: We don't support e8m0 scales at this moment.
      vllm::cvt_fp16_to_fp4_sf_major<cuda_type, false>
-          <<<grid, block, 0, stream>>>(
-              m, n, sf_n_unpadded, num_packed_cols, input_ptr, input_sf_ptr,
+          <<<grid, block, 0, stream>>>(m, n, sf_n_unpadded, num_packed_cols,
+                                       input_ptr, input_sf_ptr,
              reinterpret_cast<uint32_t*>(output_ptr),
              reinterpret_cast<uint32_t*>(sf_out));
    });


@@ -0,0 +1,67 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include "cutlass_extensions/common.hpp"
#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
torch::Tensor const& B,
torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha);
#endif
#if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
void cutlass_scaled_fp4_mm_sm120a(torch::Tensor& D, torch::Tensor const& A,
torch::Tensor const& B,
torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha);
#endif
void cutlass_scaled_fp4_mm(torch::Tensor& D, const torch::Tensor& A,
const torch::Tensor& B, const torch::Tensor& A_sf,
const torch::Tensor& B_sf,
const torch::Tensor& alpha) {
// Make sure we're on A's device.
const c10::cuda::OptionalCUDAGuard device_guard(device_of(A));
const int32_t sm = get_sm_version_num();
#if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100
if (sm >= 100 && sm < 120) {
cutlass_scaled_fp4_mm_sm100a(D, A, B, A_sf, B_sf, alpha);
return;
}
#endif
#if defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120
if (sm >= 120 && sm < 130) {
cutlass_scaled_fp4_mm_sm120a(D, A, B, A_sf, B_sf, alpha);
return;
}
#endif
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 mm kernel for SM ", sm,
". Recompile with CUDA >= 12.8 and CC >= 100.");
}
bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability) {
int runtimeVersion;
cudaRuntimeGetVersion(&runtimeVersion);
return cuda_device_capability >= 100 && runtimeVersion >= 12080;
}
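
For reference, cutlass_scaled_mm_supports_fp4 above gates FP4 GEMM on two conditions: a device capability of at least 100 (SM100-class hardware) and a CUDA runtime of at least 12.8 (version code 12080). A hedged, self-contained sketch of how a caller might evaluate the same two conditions directly through the CUDA runtime API; the thresholds come from the file above, everything else is illustrative and needs a CUDA-capable machine to run meaningfully:

#include <cuda_runtime_api.h>
#include <cstdio>

int main() {
  int device = 0, major = 0, minor = 0, runtime = 0;
  cudaGetDevice(&device);
  cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device);
  cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device);
  cudaRuntimeGetVersion(&runtime);
  // Same encoding the dispatcher uses: capability 100 == SM100.
  const int capability = major * 10 + minor;
  const bool fp4_ok = capability >= 100 && runtime >= 12080;
  std::printf("capability=%d runtime=%d -> fp4 mm %s\n", capability, runtime,
              fp4_ok ? "supported" : "unsupported");
  return 0;
}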


@@ -14,9 +14,10 @@
  * limitations under the License.
  */
-#include <torch/csrc/stable/tensor.h>
-#include "libtorch_stable/torch_utils.h"
+#include <torch/all.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
 #include "cutlass_extensions/common.hpp"
@@ -126,9 +127,8 @@ struct Fp4GemmSm100 {
 template <typename Config>
 typename Config::Gemm::Arguments args_from_options(
-    torch::stable::Tensor& D, torch::stable::Tensor const& A,
-    torch::stable::Tensor const& B, torch::stable::Tensor const& A_sf,
-    torch::stable::Tensor const& B_sf, torch::stable::Tensor const& alpha,
+    at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
+    at::Tensor const& A_sf, at::Tensor const& B_sf, at::Tensor const& alpha,
     int64_t M, int64_t N, int64_t K) {
   using ElementA = typename Config::Gemm::ElementA;
   using ElementB = typename Config::Gemm::ElementB;
@@ -174,20 +174,19 @@ typename Config::Gemm::Arguments args_from_options(
 }
 template <typename Config>
-void runGemm(torch::stable::Tensor& D, torch::stable::Tensor const& A,
-             torch::stable::Tensor const& B, torch::stable::Tensor const& A_sf,
-             torch::stable::Tensor const& B_sf,
-             torch::stable::Tensor const& alpha, int64_t m, int64_t n,
-             int64_t k, cudaStream_t stream) {
+void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
+             at::Tensor const& A_sf, at::Tensor const& B_sf,
+             at::Tensor const& alpha, int64_t m, int64_t n, int64_t k,
+             cudaStream_t stream) {
   typename Config::Gemm gemm;
   auto arguments =
       args_from_options<Config>(D, A, B, A_sf, B_sf, alpha, m, n, k);
   size_t workspace_size = Config::Gemm::get_workspace_size(arguments);
-  auto workspace =
-      torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
-                           std::nullopt, A.device());
+  auto const workspace_options =
+      torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
+  auto workspace = torch::empty(workspace_size, workspace_options);
   CUTLASS_CHECK(gemm.can_implement(arguments));
@@ -198,13 +197,12 @@ void runGemm(torch::stable::Tensor& D, torch::stable::Tensor const& A,
 // Dispatch function to select appropriate config based on M
 template <typename OutType>
-void cutlass_fp4_gemm_dispatch(torch::stable::Tensor& D,
-                               torch::stable::Tensor const& A,
-                               torch::stable::Tensor const& B,
-                               torch::stable::Tensor const& A_sf,
-                               torch::stable::Tensor const& B_sf,
-                               torch::stable::Tensor const& alpha, int64_t m,
-                               int64_t n, int64_t k, cudaStream_t stream) {
+void cutlass_fp4_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
+                               torch::Tensor const& B,
+                               torch::Tensor const& A_sf,
+                               torch::Tensor const& B_sf,
+                               torch::Tensor const& alpha, int64_t m, int64_t n,
+                               int64_t k, cudaStream_t stream) {
   uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
   if (mp2 <= 16) {
@@ -224,65 +222,61 @@ void cutlass_fp4_gemm_dispatch(torch::stable::Tensor& D,
 #else
 template <typename OutType>
-void cutlass_fp4_gemm_dispatch(torch::stable::Tensor& D,
-                               torch::stable::Tensor const& A,
-                               torch::stable::Tensor const& B,
-                               torch::stable::Tensor const& A_sf,
-                               torch::stable::Tensor const& B_sf,
-                               torch::stable::Tensor const& alpha, int64_t m,
-                               int64_t n, int64_t k, cudaStream_t stream) {
-  STD_TORCH_CHECK(false,
-                  "Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
-                  "a CUTLASS 3.8 source directory to enable support.");
+void cutlass_fp4_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
+                               torch::Tensor const& B,
+                               torch::Tensor const& A_sf,
+                               torch::Tensor const& B_sf,
+                               torch::Tensor const& alpha, int64_t m, int64_t n,
+                               int64_t k, cudaStream_t stream) {
+  TORCH_CHECK(false,
+              "Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
+              "a CUTLASS 3.8 source directory to enable support.");
 }
 #endif  // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
 #define CHECK_TYPE(x, st, m) \
-  STD_TORCH_CHECK(x.scalar_type() == st, \
-                  ": Inconsistency of torch::stable::Tensor type:", m)
+  TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m)
 #define CHECK_TH_CUDA(x, m) \
-  STD_TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor")
+  TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor")
 #define CHECK_CONTIGUOUS(x, m) \
-  STD_TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous")
+  TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous")
 #define CHECK_INPUT(x, st, m) \
   CHECK_TH_CUDA(x, m);        \
   CHECK_CONTIGUOUS(x, m);     \
   CHECK_TYPE(x, st, m)
-constexpr auto FLOAT4_E2M1X2 = torch::headeronly::ScalarType::Byte;
-constexpr auto SF_DTYPE = torch::headeronly::ScalarType::Float8_e4m3fn;
-void cutlass_scaled_fp4_mm_sm100a(torch::stable::Tensor& D,
-                                  torch::stable::Tensor const& A,
-                                  torch::stable::Tensor const& B,
-                                  torch::stable::Tensor const& A_sf,
-                                  torch::stable::Tensor const& B_sf,
-                                  torch::stable::Tensor const& alpha) {
+constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
+constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
+void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
+                                  torch::Tensor const& B,
+                                  torch::Tensor const& A_sf,
+                                  torch::Tensor const& B_sf,
+                                  torch::Tensor const& alpha) {
   CHECK_INPUT(A, FLOAT4_E2M1X2, "a");
   CHECK_INPUT(B, FLOAT4_E2M1X2, "b");
   CHECK_INPUT(A_sf, SF_DTYPE, "scale_a");
   CHECK_INPUT(B_sf, SF_DTYPE, "scale_b");
-  CHECK_INPUT(alpha, torch::headeronly::ScalarType::Float, "alpha");
-  STD_TORCH_CHECK(A.dim() == 2, "a must be a matrix");
-  STD_TORCH_CHECK(B.dim() == 2, "b must be a matrix");
-  STD_TORCH_CHECK(A.size(1) == B.size(1),
-                  "a and b shapes cannot be multiplied (", A.size(0), "x",
-                  A.size(1), " and ", B.size(0), "x", B.size(1), ")");
-  auto const m = A.size(0);
-  auto const n = B.size(0);
-  auto const k = A.size(1) * 2;
+  CHECK_INPUT(alpha, at::ScalarType::Float, "alpha");
+  TORCH_CHECK(A.dim() == 2, "a must be a matrix");
+  TORCH_CHECK(B.dim() == 2, "b must be a matrix");
+  TORCH_CHECK(A.sizes()[1] == B.sizes()[1],
+              "a and b shapes cannot be multiplied (", A.sizes()[0], "x",
+              A.sizes()[1], " and ", B.sizes()[0], "x", B.sizes()[1], ")");
+  auto const m = A.sizes()[0];
+  auto const n = B.sizes()[0];
+  auto const k = A.sizes()[1] * 2;
   constexpr int alignment = 32;
-  STD_TORCH_CHECK(k % alignment == 0, "Expected k to be divisible by ",
-                  alignment, ", but got a shape: (", A.size(0), "x", A.size(1),
-                  "), k: ", k, ".");
-  STD_TORCH_CHECK(n % alignment == 0, "Expected n to be divisible by ",
-                  alignment, ", but got b shape: (", B.size(0), "x", B.size(1),
-                  ").");
+  TORCH_CHECK(k % alignment == 0, "Expected k to be divisible by ", alignment,
+              ", but got a shape: (", A.sizes()[0], "x", A.sizes()[1],
+              "), k: ", k, ".");
+  TORCH_CHECK(n % alignment == 0, "Expected n to be divisible by ", alignment,
+              ", but got b shape: (", B.sizes()[0], "x", B.sizes()[1], ").");
   auto round_up = [](int x, int y) { return (x + y - 1) / y * y; };
   int rounded_m = round_up(m, 128);
@@ -291,34 +285,33 @@ void cutlass_scaled_fp4_mm_sm100a(torch::stable::Tensor& D,
   // integer.
   int rounded_k = round_up(k / 16, 4);
-  STD_TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix");
-  STD_TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix");
-  STD_TORCH_CHECK(A_sf.size(1) == B_sf.size(1),
-                  "scale_a and scale_b shapes cannot be multiplied (",
-                  A_sf.size(0), "x", A_sf.size(1), " and ", B_sf.size(0), "x",
-                  B_sf.size(1), ")");
-  STD_TORCH_CHECK(A_sf.size(0) == rounded_m && A_sf.size(1) == rounded_k,
-                  "scale_a must be padded and swizzled to a shape (", rounded_m,
-                  "x", rounded_k, "), but got a shape (", A_sf.size(0), "x",
-                  A_sf.size(1), ")");
-  STD_TORCH_CHECK(B_sf.size(0) == rounded_n && B_sf.size(1) == rounded_k,
-                  "scale_b must be padded and swizzled to a shape (", rounded_n,
-                  "x", rounded_k, "), but got a shape (", B_sf.size(0), "x",
-                  B_sf.size(1), ")");
-  auto out_dtype = D.scalar_type();
-  const torch::stable::accelerator::DeviceGuard device_guard(
-      A.get_device_index());
-  const cudaStream_t stream = get_current_cuda_stream(A.get_device_index());
-  if (out_dtype == torch::headeronly::ScalarType::Half) {
+  TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix");
+  TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix");
+  TORCH_CHECK(A_sf.sizes()[1] == B_sf.sizes()[1],
+              "scale_a and scale_b shapes cannot be multiplied (",
+              A_sf.sizes()[0], "x", A_sf.sizes()[1], " and ", B_sf.sizes()[0],
+              "x", B_sf.sizes()[1], ")");
+  TORCH_CHECK(A_sf.sizes()[0] == rounded_m && A_sf.sizes()[1] == rounded_k,
+              "scale_a must be padded and swizzled to a shape (", rounded_m,
+              "x", rounded_k, "), but got a shape (", A_sf.sizes()[0], "x",
+              A_sf.sizes()[1], ")");
+  TORCH_CHECK(B_sf.sizes()[0] == rounded_n && B_sf.sizes()[1] == rounded_k,
+              "scale_b must be padded and swizzled to a shape (", rounded_n,
+              "x", rounded_k, "), but got a shape (", B_sf.sizes()[0], "x",
+              B_sf.sizes()[1], ")");
+  auto out_dtype = D.dtype();
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device());
+  if (out_dtype == at::ScalarType::Half) {
     cutlass_fp4_gemm_dispatch<cutlass::half_t>(D, A, B, A_sf, B_sf, alpha, m, n,
                                                k, stream);
-  } else if (out_dtype == torch::headeronly::ScalarType::BFloat16) {
+  } else if (out_dtype == at::ScalarType::BFloat16) {
     cutlass_fp4_gemm_dispatch<cutlass::bfloat16_t>(D, A, B, A_sf, B_sf, alpha,
                                                    m, n, k, stream);
   } else {
-    STD_TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm (",
-                    out_dtype, ")");
+    TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm (", out_dtype,
+                ")");
   }
 }
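
The scale-factor checks in the hunk above encode the NVFP4 swizzled layout: one FP8 scale per 16 FP4 elements, with the token dimension padded to a multiple of 128 and the group dimension padded to a multiple of 4. A hedged sketch of that shape math (rounded_n is assumed to use the same 128 round-up as rounded_m, which this hunk does not show; k counts unpacked FP4 elements, i.e. twice the byte-packed A.sizes()[1]):

#include <cstdint>
#include <utility>

constexpr int64_t round_up(int64_t x, int64_t y) { return (x + y - 1) / y * y; }

// Expected {rows, cols} of a swizzled NVFP4 scale tensor for a [rows, k] operand.
std::pair<int64_t, int64_t> expected_scale_shape(int64_t rows, int64_t k) {
  int64_t padded_rows = round_up(rows, 128);    // pad M (or N) to a multiple of 128
  int64_t padded_groups = round_up(k / 16, 4);  // one scale per 16 elements, padded to 4
  return {padded_rows, padded_groups};
}

For example, m = 1 and k = 7168 gives an expected scale_a shape of 128 x 448.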

View File

@@ -14,9 +14,10 @@
  * limitations under the License.
  */
-#include <torch/csrc/stable/tensor.h>
-#include "libtorch_stable/torch_utils.h"
+#include <torch/all.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
 #include "cutlass_extensions/common.hpp"
@@ -34,19 +35,18 @@
 using namespace cute;
 #define CHECK_TYPE(x, st, m) \
-  STD_TORCH_CHECK(x.scalar_type() == st, \
-                  ": Inconsistency of torch::stable::Tensor type:", m)
+  TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m)
 #define CHECK_TH_CUDA(x, m) \
-  STD_TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor")
+  TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor")
 #define CHECK_CONTIGUOUS(x, m) \
-  STD_TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous")
+  TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous")
 #define CHECK_INPUT(x, st, m) \
   CHECK_TH_CUDA(x, m);        \
   CHECK_CONTIGUOUS(x, m);     \
   CHECK_TYPE(x, st, m)
-constexpr auto FLOAT4_E2M1X2 = torch::headeronly::ScalarType::Byte;
-constexpr auto SF_DTYPE = torch::headeronly::ScalarType::Float8_e4m3fn;
+constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
+constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
 struct sm120_fp4_config_M256 {
   using ClusterShape = Shape<_1, _1, _1>;
@@ -109,13 +109,12 @@ struct Fp4GemmSm120 {
 };
 template <typename Gemm>
-typename Gemm::Arguments args_from_options(torch::stable::Tensor& D,
-                                           torch::stable::Tensor const& A,
-                                           torch::stable::Tensor const& B,
-                                           torch::stable::Tensor const& A_sf,
-                                           torch::stable::Tensor const& B_sf,
-                                           torch::stable::Tensor const& alpha,
-                                           int M, int N, int K) {
+typename Gemm::Arguments args_from_options(at::Tensor& D, at::Tensor const& A,
+                                           at::Tensor const& B,
+                                           at::Tensor const& A_sf,
+                                           at::Tensor const& B_sf,
+                                           torch::Tensor const& alpha, int M,
+                                           int N, int K) {
   using ElementA = typename Gemm::ElementA;
   using ElementB = typename Gemm::ElementB;
   using ElementD = typename Gemm::ElementD;
@@ -159,19 +158,18 @@ typename Gemm::Arguments args_from_options(torch::stable::Tensor& D,
 }
 template <typename Gemm>
-void runGemm(torch::stable::Tensor& D, torch::stable::Tensor const& A,
-             torch::stable::Tensor const& B, torch::stable::Tensor const& A_sf,
-             torch::stable::Tensor const& B_sf,
-             torch::stable::Tensor const& alpha, int M, int N, int K,
-             cudaStream_t stream) {
+void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
+             at::Tensor const& A_sf, at::Tensor const& B_sf,
+             torch::Tensor const& alpha, int M, int N, int K,
+             cudaStream_t stream) {
   Gemm gemm;
   auto arguments = args_from_options<Gemm>(D, A, B, A_sf, B_sf, alpha, M, N, K);
   size_t workspace_size = Gemm::get_workspace_size(arguments);
-  auto workspace =
-      torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
-                           std::nullopt, A.device());
+  auto const workspace_options =
+      torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
+  auto workspace = torch::empty(workspace_size, workspace_options);
   CUTLASS_CHECK(gemm.can_implement(arguments));
@@ -180,13 +178,12 @@ void runGemm(torch::stable::Tensor& D, torch::stable::Tensor const& A,
   CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream));
 }
-void cutlass_fp4_bf16_gemm_dispatch(torch::stable::Tensor& D,
-                                    torch::stable::Tensor const& A,
-                                    torch::stable::Tensor const& B,
-                                    torch::stable::Tensor const& A_sf,
-                                    torch::stable::Tensor const& B_sf,
-                                    torch::stable::Tensor const& alpha, int m,
-                                    int n, int k, cudaStream_t stream) {
+void cutlass_fp4_bf16_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
+                                    torch::Tensor const& B,
+                                    torch::Tensor const& A_sf,
+                                    torch::Tensor const& B_sf,
+                                    torch::Tensor const& alpha, int m, int n,
+                                    int k, cudaStream_t stream) {
   uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
   if (mp2 <= 256) {
     runGemm<Fp4GemmSm120<sm120_fp4_config_M256, cutlass::bfloat16_t>::Gemm>(
@@ -197,13 +194,12 @@ void cutlass_fp4_bf16_gemm_dispatch(torch::stable::Tensor& D,
   }
 }
-void cutlass_fp4_f16_gemm_dispatch(torch::stable::Tensor& D,
-                                   torch::stable::Tensor const& A,
-                                   torch::stable::Tensor const& B,
-                                   torch::stable::Tensor const& A_sf,
-                                   torch::stable::Tensor const& B_sf,
-                                   torch::stable::Tensor const& alpha, int m,
-                                   int n, int k, cudaStream_t stream) {
+void cutlass_fp4_f16_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
+                                   torch::Tensor const& B,
+                                   torch::Tensor const& A_sf,
+                                   torch::Tensor const& B_sf,
+                                   torch::Tensor const& alpha, int m, int n,
+                                   int k, cudaStream_t stream) {
   uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
   if (mp2 <= 256) {
     runGemm<Fp4GemmSm120<sm120_fp4_config_M256, cutlass::half_t>::Gemm>(
@@ -214,12 +210,11 @@ void cutlass_fp4_f16_gemm_dispatch(torch::stable::Tensor& D,
   }
 }
-void cutlass_scaled_fp4_mm_sm120a(torch::stable::Tensor& D,
-                                  torch::stable::Tensor const& A,
-                                  torch::stable::Tensor const& B,
-                                  torch::stable::Tensor const& A_sf,
-                                  torch::stable::Tensor const& B_sf,
-                                  torch::stable::Tensor const& alpha) {
+void cutlass_scaled_fp4_mm_sm120a(torch::Tensor& D, torch::Tensor const& A,
+                                  torch::Tensor const& B,
+                                  torch::Tensor const& A_sf,
+                                  torch::Tensor const& B_sf,
+                                  torch::Tensor const& alpha) {
 #if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
   CHECK_INPUT(A, FLOAT4_E2M1X2, "a");
   CHECK_INPUT(B, FLOAT4_E2M1X2, "b");
@@ -227,25 +222,24 @@ void cutlass_scaled_fp4_mm_sm120a(torch::stable::Tensor& D,
   CHECK_INPUT(A_sf, SF_DTYPE, "scale_a");
   CHECK_INPUT(B_sf, SF_DTYPE, "scale_b");
-  CHECK_INPUT(alpha, torch::headeronly::ScalarType::Float, "alpha");
-  STD_TORCH_CHECK(A.dim() == 2, "a must be a matrix");
-  STD_TORCH_CHECK(B.dim() == 2, "b must be a matrix");
-  STD_TORCH_CHECK(A.size(1) == B.size(1),
-                  "a and b shapes cannot be multiplied (", A.size(0), "x",
-                  A.size(1), " and ", B.size(0), "x", B.size(1), ")");
-  auto const m = A.size(0);
-  auto const n = B.size(0);
-  auto const k = A.size(1) * 2;
+  CHECK_INPUT(alpha, at::ScalarType::Float, "alpha");
+  TORCH_CHECK(A.dim() == 2, "a must be a matrix");
+  TORCH_CHECK(B.dim() == 2, "b must be a matrix");
+  TORCH_CHECK(A.sizes()[1] == B.sizes()[1],
+              "a and b shapes cannot be multiplied (", A.sizes()[0], "x",
+              A.sizes()[1], " and ", B.sizes()[0], "x", B.sizes()[1], ")");
+  auto const m = A.sizes()[0];
+  auto const n = B.sizes()[0];
+  auto const k = A.sizes()[1] * 2;
   constexpr int alignment = 32;
-  STD_TORCH_CHECK(k % alignment == 0, "Expected k to be divisible by ",
-                  alignment, ", but got a shape: (", A.size(0), "x", A.size(1),
-                  "), k: ", k, ".");
-  STD_TORCH_CHECK(n % alignment == 0, "Expected n to be divisible by ",
-                  alignment, ", but got b shape: (", B.size(0), "x", B.size(1),
-                  ").");
+  TORCH_CHECK(k % alignment == 0, "Expected k to be divisible by ", alignment,
+              ", but got a shape: (", A.sizes()[0], "x", A.sizes()[1],
+              "), k: ", k, ".");
+  TORCH_CHECK(n % alignment == 0, "Expected n to be divisible by ", alignment,
+              ", but got b shape: (", B.sizes()[0], "x", B.sizes()[1], ").");
   auto round_up = [](int x, int y) { return (x + y - 1) / y * y; };
   int rounded_m = round_up(m, 128);
@@ -254,38 +248,37 @@ void cutlass_scaled_fp4_mm_sm120a(torch::stable::Tensor& D,
   // integer.
   int rounded_k = round_up(k / 16, 4);
-  STD_TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix");
-  STD_TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix");
-  STD_TORCH_CHECK(A_sf.size(1) == B_sf.size(1),
-                  "scale_a and scale_b shapes cannot be multiplied (",
-                  A_sf.size(0), "x", A_sf.size(1), " and ", B_sf.size(0), "x",
-                  B_sf.size(1), ")");
-  STD_TORCH_CHECK(A_sf.size(0) == rounded_m && A_sf.size(1) == rounded_k,
-                  "scale_a must be padded and swizzled to a shape (", rounded_m,
-                  "x", rounded_k, "), but got a shape (", A_sf.size(0), "x",
-                  A_sf.size(1), ")");
-  STD_TORCH_CHECK(B_sf.size(0) == rounded_n && B_sf.size(1) == rounded_k,
-                  "scale_b must be padded and swizzled to a shape (", rounded_n,
-                  "x", rounded_k, "), but got a shape (", B_sf.size(0), "x",
-                  B_sf.size(1), ")");
-  auto out_dtype = D.scalar_type();
-  const torch::stable::accelerator::DeviceGuard device_guard(
-      A.get_device_index());
-  const cudaStream_t stream = get_current_cuda_stream(A.get_device_index());
-  if (out_dtype == torch::headeronly::ScalarType::BFloat16) {
+  TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix");
+  TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix");
+  TORCH_CHECK(A_sf.sizes()[1] == B_sf.sizes()[1],
+              "scale_a and scale_b shapes cannot be multiplied (",
+              A_sf.sizes()[0], "x", A_sf.sizes()[1], " and ", B_sf.sizes()[0],
+              "x", B_sf.sizes()[1], ")");
+  TORCH_CHECK(A_sf.sizes()[0] == rounded_m && A_sf.sizes()[1] == rounded_k,
+              "scale_a must be padded and swizzled to a shape (", rounded_m,
+              "x", rounded_k, "), but got a shape (", A_sf.sizes()[0], "x",
+              A_sf.sizes()[1], ")");
+  TORCH_CHECK(B_sf.sizes()[0] == rounded_n && B_sf.sizes()[1] == rounded_k,
+              "scale_b must be padded and swizzled to a shape (", rounded_n,
+              "x", rounded_k, "), but got a shape (", B_sf.sizes()[0], "x",
+              B_sf.sizes()[1], ")");
+  auto out_dtype = D.dtype();
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device());
+  if (out_dtype == at::ScalarType::BFloat16) {
     return cutlass_fp4_bf16_gemm_dispatch(D, A, B, A_sf, B_sf, alpha, m, n, k,
                                           stream);
-  } else if (out_dtype == torch::headeronly::ScalarType::Half) {
+  } else if (out_dtype == at::ScalarType::Half) {
     return cutlass_fp4_f16_gemm_dispatch(D, A, B, A_sf, B_sf, alpha, m, n, k,
                                          stream);
   } else {
-    STD_TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm sm120 (",
-                    out_dtype, ")");
+    TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm sm120 (",
+                out_dtype, ")");
   }
 #else
-  STD_TORCH_CHECK(false,
-                  "Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
-                  "a CUTLASS 3.8 source directory to enable support.");
+  TORCH_CHECK(false,
+              "Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
+              "a CUTLASS 3.8 source directory to enable support.");
 #endif  // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)

View File

@@ -20,7 +20,7 @@
 #include <cuda_fp8.h>
 #include <utility>
-#include "cuda_vec_utils.cuh"
+#include "../../cuda_vec_utils.cuh"
 #if defined(NVFP4_ENABLE_ELTS16) && defined(CUDA_VERSION) && \
     CUDA_VERSION >= 12090

View File

@@ -1,169 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "../../dispatch_utils.h"
#include "quant_conversions.cuh"
#include "../w8a8/fp8/common.cuh"
namespace vllm {
// Logic: one thread block per (token, group) pair
template <typename scalar_t, typename scalar_out_t, bool is_scale_transposed,
int32_t group_size>
__global__ void silu_and_mul_per_block_quant_kernel(
scalar_out_t* __restrict__ out, // Output: [num_tokens, hidden_size] in
// FP8/INT8
float* __restrict__ scales, // Output: [num_tokens, hidden_size /
// group_size] or [hidden_size / group_size,
// num_tokens]
scalar_t const* __restrict__ input, // Input: [num_tokens, hidden_size * 2]
float const* scale_ub, // Optional scale upper bound
int32_t const hidden_size // Output hidden size (input is 2x this)
) {
static_assert((group_size & (group_size - 1)) == 0,
"group_size must be a power of 2 for correct reduction");
// Grid: (num_tokens, num_groups)
int const token_idx = blockIdx.x;
int const group_idx = blockIdx.y;
int const tid = threadIdx.x; // tid in [0, group_size)
int const num_tokens = gridDim.x;
// Input layout: [gate || up] concatenated along last dimension
int const input_stride = hidden_size * 2;
int const group_start = group_idx * group_size;
// Pointers to this token's data
scalar_t const* token_input_gate =
input + token_idx * input_stride + group_start;
scalar_t const* token_input_up = token_input_gate + hidden_size;
scalar_out_t* token_output = out + token_idx * hidden_size + group_start;
// Scale pointer for this group
int const num_groups = gridDim.y;
float* group_scale_ptr = is_scale_transposed
? scales + group_idx * num_tokens + token_idx
: scales + token_idx * num_groups + group_idx;
// Shared memory for reduction (compile-time sized)
__shared__ float shared_max[group_size];
// Step 1: Each thread loads one element, computes SiLU, stores in register
float gate = static_cast<float>(token_input_gate[tid]);
float up = static_cast<float>(token_input_up[tid]);
// Compute SiLU(gate) * up
float sigmoid_gate = 1.0f / (1.0f + expf(-gate));
float silu_gate = gate * sigmoid_gate;
float result = silu_gate * up; // Keep in register
// Step 2: Reduce to find group max
shared_max[tid] = fabsf(result);
__syncthreads();
// Power-of-2 reduction (group_size guaranteed to be power of 2)
#pragma unroll
for (int stride = group_size / 2; stride > 0; stride >>= 1) {
if (tid < stride) {
shared_max[tid] = fmaxf(shared_max[tid], shared_max[tid + stride]);
}
__syncthreads();
}
// Step 3: Compute scale (thread 0), broadcast via shared memory
if (tid == 0) {
float group_max = shared_max[0];
float const quant_range = quant_type_max_v<scalar_out_t>;
float group_scale = group_max / quant_range;
// Apply scale upper bound if provided
if (scale_ub != nullptr) {
group_scale = fminf(group_scale, *scale_ub);
}
// Use minimum safe scaling factor
group_scale = fmaxf(group_scale, min_scaling_factor<scalar_out_t>::val());
// Store scale to global memory
*group_scale_ptr = group_scale;
// Reuse shared_max[0] to broadcast scale
shared_max[0] = group_scale;
}
__syncthreads();
float group_scale = shared_max[0];
// Step 4: Quantize and write output
token_output[tid] =
vllm::ScaledQuant<scalar_out_t, false>::quant_fn(result, group_scale);
}
} // namespace vllm
void silu_and_mul_per_block_quant(torch::Tensor& out,
torch::Tensor const& input,
torch::Tensor& scales, int64_t group_size,
std::optional<torch::Tensor> scale_ub,
bool is_scale_transposed) {
static c10::ScalarType kFp8Type = is_fp8_ocp()
? c10::ScalarType::Float8_e4m3fn
: c10::ScalarType::Float8_e4m3fnuz;
TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8);
TORCH_CHECK(out.is_contiguous() && input.is_contiguous());
TORCH_CHECK(
input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16,
"Input must be FP16 or BF16");
TORCH_CHECK(scales.dtype() == torch::kFloat32, "Scales must be FP32");
TORCH_CHECK(group_size == 128 || group_size == 64,
"Unsupported group size: ", group_size);
if (scale_ub.has_value()) {
TORCH_CHECK(out.dtype() == kFp8Type);
}
int32_t hidden_size = out.size(-1);
auto num_tokens = input.size(0);
int32_t num_groups = hidden_size / group_size;
TORCH_CHECK(input.size(-1) == hidden_size * 2,
"input last dim must be 2x output hidden_size");
TORCH_CHECK(hidden_size % group_size == 0,
"hidden_size must be divisible by group_size");
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 grid(num_tokens, num_groups);
dim3 block(group_size);
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "silu_and_mul_per_block_quant", [&] {
using scalar_in_t = scalar_t;
VLLM_DISPATCH_QUANT_TYPES(
out.scalar_type(), "silu_and_mul_per_block_quant", [&] {
using scalar_out_t = scalar_t;
VLLM_DISPATCH_GROUP_SIZE(group_size, gs, [&] {
VLLM_DISPATCH_BOOL(is_scale_transposed, transpose_scale, [&] {
vllm::silu_and_mul_per_block_quant_kernel<
scalar_in_t, scalar_out_t, transpose_scale, gs>
<<<grid, block, 0, stream>>>(
out.data_ptr<scalar_out_t>(),
scales.data_ptr<float>(),
input.data_ptr<scalar_in_t>(),
scale_ub.has_value() ? scale_ub->data_ptr<float>()
: nullptr,
hidden_size);
});
});
});
});
}
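
For context on the removed entry point, here is a hedged usage sketch against the signature shown above (assuming the op is compiled into the extension and declared in scope, hidden_size is divisible by group_size, and FP8 output is wanted; all tensor names and sizes are illustrative):

#include <torch/all.h>
#include <optional>

void silu_and_mul_per_block_quant_example() {
  const int64_t num_tokens = 8, hidden_size = 4096, group_size = 128;
  auto opts = torch::TensorOptions().device(torch::kCUDA);

  // Input holds [gate || up] concatenated along the last dimension.
  auto input = torch::randn({num_tokens, hidden_size * 2},
                            opts.dtype(torch::kBFloat16));
  auto out = torch::empty({num_tokens, hidden_size},
                          opts.dtype(torch::kFloat8_e4m3fn));
  // One FP32 scale per (token, group); not transposed in this example.
  auto scales = torch::empty({num_tokens, hidden_size / group_size},
                             opts.dtype(torch::kFloat32));

  silu_and_mul_per_block_quant(out, input, scales, group_size,
                               /*scale_ub=*/std::nullopt,
                               /*is_scale_transposed=*/false);
}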

View File

@@ -154,7 +154,6 @@ struct MacheteCollectiveMma {
   struct DispatchPolicy {
     constexpr static int Stages = PipelineStages;
     using ClusterShape = ClusterShape_MNK;
-    using ArchTag = arch::Sm90;
     using Schedule = KernelScheduleType;
   };

View File

@@ -108,15 +108,6 @@ QUANT_CONFIGS = [
         "thread_m_blocks": THREAD_M_BLOCKS,
         "group_blocks": [2],
     },
-    # MXFP8
-    {
-        "a_type": ["kBFloat16"],
-        "b_type": "kFE4M3fn",
-        "s_type": "kFE8M0fnu",
-        "thread_configs": THREAD_CONFIGS,
-        "thread_m_blocks": THREAD_M_BLOCKS,
-        "group_blocks": [2],
-    },
     # AWQ-INT4 with INT8 activation
     {
         "a_type": ["kS8"],

View File

@@ -591,9 +591,6 @@ torch::Tensor marlin_gemm(
           "When b_type = float4_e2m1f, b_scale scalar type must be",
           "float8_e4m3fn (for NVFP4) or float8_e8m0fnu (for MXFP4).");
     }
-  } else if (b_type_id == vllm::kFE4M3fn.id() &&
-             b_scales.scalar_type() == at::ScalarType::Float8_e8m0fnu) {
-    s_type_id = vllm::kFE8M0fnu.id();
   }
   vllm::ScalarType a_type = vllm::ScalarType::from_id(a_type_id);

View File

@@ -327,9 +327,6 @@ __global__ void Marlin(
   if constexpr (b_type == vllm::kFE2M1f) {
     static_assert(s_type == vllm::kFE4M3fn && group_blocks == 1 ||
                   s_type == vllm::kFE8M0fnu && group_blocks == 2);
-  } else if constexpr (s_type == vllm::kFE8M0fnu) {
-    // MXFP8: FP8 weights with e8m0 microscaling block scales
-    static_assert(b_type == vllm::kFE4M3fn && group_blocks == 2);
   } else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
     static_assert(s_type == vllm::kBFloat16);
   } else if constexpr (std::is_same<scalar_t, half>::value) {
@@ -337,7 +334,6 @@
   }
   constexpr bool is_a_8bit = a_type.size_bits() == 8;
-  constexpr bool is_8bit_scale = s_type.size_bits() == 8;
   if constexpr (!is_a_8bit) {
     static_assert(std::is_same<scalar_t, c_scalar_t>::value);
   }
@@ -347,7 +343,7 @@
       b_type == vllm::kU4B8 || b_type == vllm::kU8B128;
   // see comments of dequant.h for more details
   constexpr bool dequant_skip_flop =
-      is_a_8bit || (b_type == vllm::kFE4M3fn && !(s_type == vllm::kFE8M0fnu)) ||
+      is_a_8bit || b_type == vllm::kFE4M3fn ||
       b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn ||
       has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
       has_zp && !is_zp_float && !(b_type == vllm::kU8);
@@ -559,8 +555,9 @@
   constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
   // Scale sizes/strides without act_order
-  int s_gl_stride = prob_n / (is_8bit_scale ? 16 : 8);
-  constexpr int s_sh_stride = 16 * thread_n_blocks / (is_8bit_scale ? 16 : 8);
+  int s_gl_stride = prob_n / (b_type == vllm::kFE2M1f ? 16 : 8);
+  constexpr int s_sh_stride =
+      16 * thread_n_blocks / (b_type == vllm::kFE2M1f ? 16 : 8);
   constexpr int s_tb_groups =
       !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
           ? thread_k_blocks / group_blocks
@@ -1000,7 +997,7 @@
       int4* sh_s_stage = sh_s + s_sh_stage * pipe;
-      if constexpr (!is_8bit_scale) {
+      if constexpr (b_type_id != vllm::kFE2M1f.id()) {
         reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
             sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
       } else {
@@ -1009,7 +1006,7 @@
             sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
       }
     } else if (group_blocks >= b_sh_wr_iters) {
-      if constexpr (!is_8bit_scale) {
+      if constexpr (b_type_id != vllm::kFE2M1f.id()) {
         reinterpret_cast<int4*>(&frag_s[1])[0] =
             reinterpret_cast<int4*>(&frag_s[0])[0];
       } else {
@@ -1210,7 +1207,7 @@
       }
     }
-    if constexpr (s_type == vllm::kFE4M3fn || s_type == vllm::kFE8M0fnu) {
+    if constexpr (b_type == vllm::kFE2M1f) {
       int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];
       int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];

View File

@@ -2,10 +2,9 @@
 // clang-format will break include orders
 // clang-format off
-#include <torch/csrc/stable/tensor.h>
-#include <torch/csrc/stable/ops.h>
-#include "libtorch_stable/torch_utils.h"
+#include <torch/all.h>
+#include <ATen/cuda/CUDAContext.h>
 #include "cutlass/cutlass.h"
@@ -26,14 +25,14 @@
 namespace vllm::c3x {
 static inline cute::Shape<int, int, int, int> get_problem_shape(
-    torch::stable::Tensor const& a, torch::stable::Tensor const& b) {
+    torch::Tensor const& a, torch::Tensor const& b) {
   int32_t m = a.size(0), n = b.size(1), k = a.size(1);
   return {m, n, k, 1};
 }
 template <typename GemmKernel>
 void cutlass_gemm_caller(
-    torch::stable::Device device, cute::Shape<int, int, int, int> prob_shape,
+    torch::Device device, cute::Shape<int, int, int, int> prob_shape,
     typename GemmKernel::MainloopArguments mainloop_args,
     typename GemmKernel::EpilogueArguments epilogue_args,
     typename GemmKernel::TileSchedulerArguments scheduler = {}) {
@@ -51,20 +50,19 @@ void cutlass_gemm_caller(
   CUTLASS_CHECK(gemm_op.can_implement(args));
   size_t workspace_size = gemm_op.get_workspace_size(args);
-  auto workspace =
-      torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
-                           std::nullopt, device);
-  auto stream = get_current_cuda_stream(device.index());
+  auto const workspace_options =
+      torch::TensorOptions().dtype(torch::kUInt8).device(device);
+  auto workspace = torch::empty(workspace_size, workspace_options);
+  auto stream = at::cuda::getCurrentCUDAStream(device.index());
   cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
   CUTLASS_CHECK(status);
 }
 template <typename Gemm, typename... EpilogueArgs>
-void cutlass_gemm_caller(torch::stable::Tensor& out,
-                         torch::stable::Tensor const& a,
-                         torch::stable::Tensor const& b,
-                         EpilogueArgs&&... epilogue_params) {
+void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
+                         torch::Tensor const& b,
+                         EpilogueArgs&&... epilogue_params) {
   using ElementAB = typename Gemm::ElementAB;
   using ElementC = typename Gemm::ElementC;

View File

@@ -4,12 +4,13 @@
 namespace vllm {
-void cutlass_scaled_mm_azp_sm90_int8(
-    torch::stable::Tensor& out, torch::stable::Tensor const& a,
-    torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
-    torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
-    std::optional<torch::stable::Tensor> const& azp,
-    std::optional<torch::stable::Tensor> const& bias) {
+void cutlass_scaled_mm_azp_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
+                                     torch::Tensor const& b,
+                                     torch::Tensor const& a_scales,
+                                     torch::Tensor const& b_scales,
+                                     torch::Tensor const& azp_adj,
+                                     std::optional<torch::Tensor> const& azp,
+                                     std::optional<torch::Tensor> const& bias) {
   if (azp) {
     return cutlass_scaled_mm_sm90_int8_epilogue<
         c3x::ScaledEpilogueBiasAzpToken>(out, a, b, a_scales, b_scales, azp_adj,

View File

@@ -0,0 +1,23 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_blockwise_sm100_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace vllm {
void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
if (out.dtype() == torch::kBFloat16) {
cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::bfloat16_t>(
out, a, b, a_scales, b_scales);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::half_t>(
out, a, b, a_scales, b_scales);
}
}
} // namespace vllm

View File

@@ -1,7 +1,5 @@
 #pragma once
-#include <torch/headeronly/util/shim_utils.h>
 #include "cuda_utils.h"
 #include "cutlass/cutlass.h"
 #include "cutlass/numeric_types.h"
@@ -132,10 +130,10 @@ struct cutlass_3x_gemm_fp8_blockwise {
 };
 template <typename Gemm>
-void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Tensor const& a,
-                                   torch::stable::Tensor const& b,
-                                   torch::stable::Tensor const& a_scales,
-                                   torch::stable::Tensor const& b_scales) {
+void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
+                                   torch::Tensor const& b,
+                                   torch::Tensor const& a_scales,
+                                   torch::Tensor const& b_scales) {
   static constexpr bool swap_ab = Gemm::swap_ab;
   using GemmKernel = typename Gemm::GemmKernel;
   using StrideA = typename Gemm::GemmKernel::StrideA;
@@ -202,11 +200,11 @@ void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Te
 }
 template <typename OutType>
-void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::stable::Tensor& out,
-                                               torch::stable::Tensor const& a,
-                                               torch::stable::Tensor const& b,
-                                               torch::stable::Tensor const& a_scales,
-                                               torch::stable::Tensor const& b_scales) {
+void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out,
+                                               torch::Tensor const& a,
+                                               torch::Tensor const& b,
+                                               torch::Tensor const& a_scales,
+                                               torch::Tensor const& b_scales) {
   int32_t m = a.size(0), n = b.size(1), k = a.size(1), sms;
   cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device());

View File

@@ -0,0 +1,23 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_blockwise_sm120_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace vllm {
void cutlass_scaled_mm_blockwise_sm120_fp8(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
if (out.dtype() == torch::kBFloat16) {
cutlass_gemm_blockwise_sm120_fp8_dispatch<cutlass::bfloat16_t>(
out, a, b, a_scales, b_scales);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
cutlass_gemm_blockwise_sm120_fp8_dispatch<cutlass::half_t>(
out, a, b, a_scales, b_scales);
}
}
} // namespace vllm

View File

@@ -1,7 +1,5 @@
 #pragma once
-#include <torch/headeronly/util/shim_utils.h>
 #include "cuda_utils.h"
 #include "cutlass/cutlass.h"
 #include "cutlass/numeric_types.h"
@@ -140,10 +138,10 @@ struct sm120_blockwise_fp8_config_M64 {
 };
 template <typename Gemm>
-void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Tensor const& a,
-                                   torch::stable::Tensor const& b,
-                                   torch::stable::Tensor const& a_scales,
-                                   torch::stable::Tensor const& b_scales) {
+void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
+                                   torch::Tensor const& b,
+                                   torch::Tensor const& a_scales,
+                                   torch::Tensor const& b_scales) {
   using GemmKernel = typename Gemm::GemmKernel;
   using StrideA = typename Gemm::GemmKernel::StrideA;
   using StrideB = typename Gemm::GemmKernel::StrideB;
@@ -198,11 +196,11 @@ void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Te
 }
 template <typename OutType>
-void cutlass_gemm_blockwise_sm120_fp8_dispatch(torch::stable::Tensor& out,
-                                               torch::stable::Tensor const& a,
-                                               torch::stable::Tensor const& b,
-                                               torch::stable::Tensor const& a_scales,
-                                               torch::stable::Tensor const& b_scales) {
+void cutlass_gemm_blockwise_sm120_fp8_dispatch(torch::Tensor& out,
+                                               torch::Tensor const& a,
+                                               torch::Tensor const& b,
+                                               torch::Tensor const& a_scales,
+                                               torch::Tensor const& b_scales) {
   int M = a.size(0);
   if (M <= 256) {
     using Gemm = typename sm120_blockwise_fp8_config_M64<OutType>::Gemm;

View File

@@ -0,0 +1,24 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_blockwise_sm90_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace vllm {
void cutlass_scaled_mm_blockwise_sm90_fp8(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
if (out.dtype() == torch::kBFloat16) {
cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::bfloat16_t>(
out, a, b, a_scales, b_scales);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::half_t>(
out, a, b, a_scales, b_scales);
}
}
} // namespace vllm

View File

@@ -1,7 +1,5 @@
 #pragma once
-#include <torch/headeronly/util/shim_utils.h>
 #include "cutlass/cutlass.h"
 #include "cutlass/numeric_types.h"
@@ -103,10 +101,10 @@ struct cutlass_3x_gemm_fp8_blockwise {
 };
 template <typename Gemm>
-void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Tensor const& a,
-                                   torch::stable::Tensor const& b,
-                                   torch::stable::Tensor const& a_scales,
-                                   torch::stable::Tensor const& b_scales) {
+void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
+                                   torch::Tensor const& b,
+                                   torch::Tensor const& a_scales,
+                                   torch::Tensor const& b_scales) {
   using GemmKernel = typename Gemm::GemmKernel;
   using StrideA = typename Gemm::GemmKernel::StrideA;
   using StrideB = typename Gemm::GemmKernel::StrideB;
@@ -122,7 +120,7 @@ void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Te
   int32_t m = a.size(0), n = b.size(1), k = a.size(1);
-  STD_TORCH_CHECK(m % 4 == 0, "m must be divisible by 4");
+  TORCH_CHECK(m % 4 == 0, "m must be divisible by 4");
   StrideA a_stride;
   StrideB b_stride;
@@ -163,11 +161,11 @@ void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Te
 }
 template <typename OutType>
-void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::stable::Tensor& out,
-                                              torch::stable::Tensor const& a,
-                                              torch::stable::Tensor const& b,
-                                              torch::stable::Tensor const& a_scales,
-                                              torch::stable::Tensor const& b_scales) {
+void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::Tensor& out,
+                                              torch::Tensor const& a,
+                                              torch::Tensor const& b,
+                                              torch::Tensor const& a_scales,
+                                              torch::Tensor const& b_scales) {
   // TODO: better heuristics
   cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
       OutType, 1, 128, 128, Shape<_128, _128, _128>,

View File

@@ -1,57 +1,52 @@
-#include <torch/csrc/stable/tensor.h>
-#include <torch/headeronly/core/ScalarType.h>
+#include <torch/all.h>
 #include "cuda_utils.h"
 #include "cutlass_extensions/common.hpp"
 template <typename Fp8Func, typename Int8Func, typename BlockwiseFunc>
-void dispatch_scaled_mm(torch::stable::Tensor& c,
-                        torch::stable::Tensor const& a,
-                        torch::stable::Tensor const& b,
-                        torch::stable::Tensor const& a_scales,
-                        torch::stable::Tensor const& b_scales,
-                        std::optional<torch::stable::Tensor> const& bias,
+void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
+                        torch::Tensor const& b, torch::Tensor const& a_scales,
+                        torch::Tensor const& b_scales,
+                        std::optional<torch::Tensor> const& bias,
                         Fp8Func fp8_func, Int8Func int8_func,
                         BlockwiseFunc blockwise_func) {
-  STD_TORCH_CHECK(a_scales.scalar_type() ==
-                  torch::headeronly::ScalarType::Float);
-  STD_TORCH_CHECK(b_scales.scalar_type() ==
-                  torch::headeronly::ScalarType::Float);
+  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
+  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
   int M = a.size(0), N = b.size(1), K = a.size(1);
   if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
       (b_scales.numel() == 1 || b_scales.numel() == b.size(1))) {
     // Standard per-tensor/per-token/per-channel scaling
-    STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
-    if (a.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn) {
+    TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
+    if (a.dtype() == torch::kFloat8_e4m3fn) {
       fp8_func(c, a, b, a_scales, b_scales, bias);
     } else {
-      STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
+      TORCH_CHECK(a.dtype() == torch::kInt8);
      if constexpr (!std::is_same_v<Int8Func, std::nullptr_t>) {
        int8_func(c, a, b, a_scales, b_scales, bias);
      } else {
        int32_t version_num = get_sm_version_num();
-        STD_TORCH_CHECK(
+        TORCH_CHECK(
            false, "Int8 not supported on SM", version_num,
            ". Use FP8 quantization instead, or run on older arch (SM < 100).");
      }
    }
  } else {
-    STD_TORCH_CHECK(a_scales.dim() == 2, "a scale must be 2d tensor.");
-    STD_TORCH_CHECK(b_scales.dim() == 2, "b scale must be 2d tensor.");
+    TORCH_CHECK(a_scales.dim() == 2, "a scale must be 2d tensor.");
+    TORCH_CHECK(b_scales.dim() == 2, "b scale must be 2d tensor.");
    int32_t version_num = get_sm_version_num();
    if (version_num >= 90) {
-      STD_TORCH_CHECK(
+      TORCH_CHECK(
          a.size(0) == a_scales.size(0) &&
              cuda_utils::ceil_div(a.size(1), int64_t(128)) == a_scales.size(1),
          "a_scale_group_shape must be [1, 128].");
-      STD_TORCH_CHECK(
+      TORCH_CHECK(
          cuda_utils::ceil_div(b.size(0), int64_t(128)) == b_scales.size(0) &&
              cuda_utils::ceil_div(b.size(1), int64_t(128)) == b_scales.size(1),
          "b_scale_group_shape must be [128, 128].");
    }
-    STD_TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
+    TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
    blockwise_func(c, a, b, a_scales, b_scales);
  }
 }
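
The blockwise branch above fixes the scale granularity at [1, 128] for a and [128, 128] for b. A hedged sketch of the scale shapes it therefore expects, assuming a is [M, K] and b is [K, N] as implied by the M/N/K extraction (helper and struct names are illustrative):

#include <cstdint>

constexpr int64_t ceil_div(int64_t a, int64_t b) { return (a + b - 1) / b; }

struct BlockwiseScaleShapes {
  int64_t a_rows, a_cols;  // a_scales: [M, ceil_div(K, 128)]
  int64_t b_rows, b_cols;  // b_scales: [ceil_div(K, 128), ceil_div(N, 128)]
};

BlockwiseScaleShapes blockwise_scale_shapes(int64_t M, int64_t N, int64_t K) {
  return {M, ceil_div(K, 128), ceil_div(K, 128), ceil_div(N, 128)};
}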

Some files were not shown because too many files have changed in this diff Show More