Compare commits
247 Commits
v0.18.2rc0
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e5de19ff9a | ||
|
|
edee96519a | ||
|
|
adaabb8a55 | ||
|
|
f7cad67412 | ||
|
|
a8134aef4e | ||
|
|
2800706f06 | ||
|
|
0d310ffbeb | ||
|
|
d5f75fdf50 | ||
|
|
827268e98d | ||
|
|
56e19d7ee2 | ||
|
|
9036d4c464 | ||
|
|
a8c6ee9b78 | ||
|
|
3b1d9c3156 | ||
|
|
54d244f28f | ||
|
|
6c749399b7 | ||
|
|
91eea72330 | ||
|
|
df2503e125 | ||
|
|
c8d98f81f6 | ||
|
|
d87fb264df | ||
|
|
66c079ae83 | ||
|
|
b6c9be509e | ||
|
|
ed733802f0 | ||
|
|
8a34c5087a | ||
|
|
ed2f282bc8 | ||
|
|
9e78555743 | ||
|
|
e80e633927 | ||
|
|
490f17d0c7 | ||
|
|
2e98406048 | ||
|
|
ef5a226819 | ||
|
|
aec18492d0 | ||
|
|
2a49284c8a | ||
|
|
d37b378762 | ||
|
|
92fbec391b | ||
|
|
2f41d6c063 | ||
|
|
3aecdf08b4 | ||
|
|
eb4205fee5 | ||
|
|
83aea2147f | ||
|
|
2e9034c998 | ||
|
|
8332078cfd | ||
|
|
ba4a78eb5d | ||
|
|
f3c7941ec8 | ||
|
|
3352bf8b03 | ||
|
|
7c94ae16c6 | ||
|
|
ad05edfbca | ||
|
|
2018137242 | ||
|
|
a776a48b1c | ||
|
|
8477fe427d | ||
|
|
e24e0a43a4 | ||
|
|
b55d830ec7 | ||
|
|
75e01a39a1 | ||
|
|
512c5eb455 | ||
|
|
13151a4df4 | ||
|
|
56c976c1b5 | ||
|
|
d74a306c4b | ||
|
|
0e9f0a516c | ||
|
|
8904fc4d19 | ||
|
|
1a2c17634e | ||
|
|
308cec5864 | ||
|
|
4e2ab1861d | ||
|
|
140cbb1186 | ||
|
|
6155bbd1dd | ||
|
|
78434b923c | ||
|
|
2488d1dca2 | ||
|
|
d734445fcd | ||
|
|
927975ead8 | ||
|
|
9ea7d670d8 | ||
|
|
7b80cd8ac3 | ||
|
|
2111997f96 | ||
|
|
5af684c319 | ||
|
|
d521dcdbcc | ||
|
|
5daf62271d | ||
|
|
ad3304425b | ||
|
|
70406eb1dc | ||
|
|
08bfedc152 | ||
|
|
0102bd2f4c | ||
|
|
83d09d36b5 | ||
|
|
92b9afeecd | ||
|
|
7310555482 | ||
|
|
96b5004b71 | ||
|
|
98e1a43af7 | ||
|
|
729eb59f60 | ||
|
|
6e1100889e | ||
|
|
edcc37a8ce | ||
|
|
79df4a794d | ||
|
|
7c139ab23f | ||
|
|
0be9516ea4 | ||
|
|
7b9de7c892 | ||
|
|
dd9342e6bc | ||
|
|
8060bb0333 | ||
|
|
da4c0e4db9 | ||
|
|
a9a0e0551f | ||
|
|
5c35517a3e | ||
|
|
a435e3108d | ||
|
|
2df2c85be4 | ||
|
|
62095e82c1 | ||
|
|
b2b2c5239e | ||
|
|
00d7b497b3 | ||
|
|
9c81f35b1a | ||
|
|
f186cfe75e | ||
|
|
dfa5062a8f | ||
|
|
e8ebbdde83 | ||
|
|
94fbb09894 | ||
|
|
419e73cdfa | ||
|
|
f01482408c | ||
|
|
bfdc0a3a99 | ||
|
|
93bada494f | ||
|
|
608914de30 | ||
|
|
4ae218c122 | ||
|
|
f40d9879f2 | ||
|
|
47e605092b | ||
|
|
e69a265135 | ||
|
|
fef56c1855 | ||
|
|
c5e3454e5a | ||
|
|
f6983f01de | ||
|
|
780ba37458 | ||
|
|
9570654c6d | ||
|
|
d56e952239 | ||
|
|
56de443db1 | ||
|
|
4dd49b06f8 | ||
|
|
f53fa26e05 | ||
|
|
1af6f78ae5 | ||
|
|
228023b3a5 | ||
|
|
9a528260ef | ||
|
|
968ed02ace | ||
|
|
7d266abb22 | ||
|
|
156405d243 | ||
|
|
99e5539a67 | ||
|
|
a88ce94bbb | ||
|
|
2a36d8fb72 | ||
|
|
93726b2a1c | ||
|
|
8617f8676b | ||
|
|
06fd9ffcc4 | ||
|
|
cab4064cd5 | ||
|
|
062f1a2d70 | ||
|
|
81994e1d0e | ||
|
|
4b506ff90a | ||
|
|
5875bb2e9c | ||
|
|
f0d3ad9f3e | ||
|
|
121ea5a21f | ||
|
|
ab79863e6c | ||
|
|
5f1de2b14b | ||
|
|
a5a623d961 | ||
|
|
f8c3af2d85 | ||
|
|
50cd5674b3 | ||
|
|
7b1a7423be | ||
|
|
97f92c6b47 | ||
|
|
46f02e00f2 | ||
|
|
6b4872240f | ||
|
|
580090db6b | ||
|
|
cb10b7e80b | ||
|
|
bf8b022e60 | ||
|
|
40ee64c00e | ||
|
|
1b117cb0ac | ||
|
|
abebd9323d | ||
|
|
25f2b55319 | ||
|
|
cb4ff07f8b | ||
|
|
a7d79fa133 | ||
|
|
fa9e68022d | ||
|
|
5506435419 | ||
|
|
311c981647 | ||
|
|
21d7ecc5b0 | ||
|
|
4729b90838 | ||
|
|
8b141ed8c3 | ||
|
|
2ad7c0335f | ||
|
|
201d2ea5bf | ||
|
|
103f0de565 | ||
|
|
32e0c0bfa2 | ||
|
|
4a06e1246e | ||
|
|
3bc2734dd0 | ||
|
|
1f5ec2889c | ||
|
|
ee3cf45739 | ||
|
|
05e68e1f81 | ||
|
|
771913e4a0 | ||
|
|
71a9125c67 | ||
|
|
66e86f1dbd | ||
|
|
bb39382b2b | ||
|
|
7b743ba953 | ||
|
|
188defbd0b | ||
|
|
08ed2b9688 | ||
|
|
ecd5443dbc | ||
|
|
58262dec6e | ||
|
|
cb3935a8fc | ||
|
|
82a006beeb | ||
|
|
a9b4f07ba2 | ||
|
|
d9408ffba3 | ||
|
|
16a65e4173 | ||
|
|
c0817e4d39 | ||
|
|
dfe5e31689 | ||
|
|
2ce3d0ce36 | ||
|
|
4eefbf9609 | ||
|
|
551b3fb39f | ||
|
|
c6f722b93e | ||
|
|
9bd7231106 | ||
|
|
73f48ce559 | ||
|
|
3aab680e3e | ||
|
|
5a2d420c17 | ||
|
|
5f96f9aff1 | ||
|
|
694449050f | ||
|
|
6241521dd2 | ||
|
|
1785dc5501 | ||
|
|
54500546ac | ||
|
|
de5e6c44c6 | ||
|
|
cb268e4e55 | ||
|
|
6183cae1bd | ||
|
|
c09ad767cd | ||
|
|
c9a9db0e02 | ||
|
|
cbe7d18096 | ||
|
|
db5d0719e1 | ||
|
|
dc0428ebb8 | ||
|
|
148c2072ec | ||
|
|
2f5c3c1ec0 | ||
|
|
fa246d5231 | ||
|
|
7cf56a59a2 | ||
|
|
5e30e9b9a9 | ||
|
|
582340f273 | ||
|
|
992368522f | ||
|
|
58ee614221 | ||
|
|
f9f6a9097a | ||
|
|
c75a313824 | ||
|
|
4f6eed3bd4 | ||
|
|
36d7f19897 | ||
|
|
2d725b89c5 | ||
|
|
ef53395e2c | ||
|
|
eb47454987 | ||
|
|
116f4be405 | ||
|
|
7b01d97a22 | ||
|
|
17b72fd1c8 | ||
|
|
c49497726b | ||
|
|
cb0b443274 | ||
|
|
40bb175027 | ||
|
|
0fab52f0aa | ||
|
|
91e4521f9f | ||
|
|
31a719bcd3 | ||
|
|
2e56975657 | ||
|
|
36f1dc19ae | ||
|
|
3dc01ef352 | ||
|
|
cc671cb110 | ||
|
|
856589ed9a | ||
|
|
517b769b58 | ||
|
|
d9b90a07ac | ||
|
|
598190aac3 | ||
|
|
b779eb3363 | ||
|
|
077a9a8e37 | ||
|
|
07edd551cc | ||
|
|
7c080dd3c5 | ||
|
|
0dd25a44ea | ||
|
|
3896e021a0 |
@@ -5,7 +5,6 @@ steps:
|
||||
depends_on: []
|
||||
device: amd_cpu
|
||||
no_plugin: true
|
||||
soft_fail: true
|
||||
commands:
|
||||
- >
|
||||
docker build
|
||||
|
||||
@@ -56,9 +56,9 @@ steps:
|
||||
'cd tests &&
|
||||
pytest -v -s v1/core --ignore=v1/core/test_reset_prefix_cache_e2e.py --ignore=v1/core/test_scheduler_e2e.py &&
|
||||
pytest -v -s v1/engine --ignore=v1/engine/test_output_processor.py &&
|
||||
pytest -v -s v1/sample --ignore=v1/sample/test_logprobs.py --ignore=v1/sample/test_logprobs_e2e.py &&
|
||||
pytest -v -s v1/sample --ignore=v1/sample/test_logprobs.py --ignore=v1/sample/test_logprobs_e2e.py -k "not test_topk_only and not test_topp_only and not test_topk_and_topp" &&
|
||||
pytest -v -s v1/worker --ignore=v1/worker/test_gpu_model_runner.py --ignore=v1/worker/test_worker_memory_snapshot.py &&
|
||||
pytest -v -s v1/structured_output &&
|
||||
pytest -v -s v1/test_serial_utils.py &&
|
||||
pytest -v -s v1/spec_decode --ignore=v1/spec_decode/test_max_len.py --ignore=v1/spec_decode/test_tree_attention.py --ignore=v1/spec_decode/test_speculators_eagle3.py --ignore=v1/spec_decode/test_acceptance_length.py &&
|
||||
pytest -v -s v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_multi_connector.py --ignore=v1/kv_connector/unit/test_nixl_connector.py --ignore=v1/kv_connector/unit/test_example_connector.py --ignore=v1/kv_connector/unit/test_lmcache_integration.py'
|
||||
pytest -v -s v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_multi_connector.py --ignore=v1/kv_connector/unit/test_example_connector.py --ignore=v1/kv_connector/unit/test_lmcache_integration.py --ignore=v1/kv_connector/unit/test_hf3fs_client.py --ignore=v1/kv_connector/unit/test_hf3fs_connector.py --ignore=v1/kv_connector/unit/test_hf3fs_metadata_server.py'
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
# For hf script, without -t option (tensor parallel size).
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-mmlupro-vllm-baseline.sh -m meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8 -l 250 -t 8 -f 5
|
||||
model_name: "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8"
|
||||
required_gpu_arch:
|
||||
- gfx942
|
||||
- gfx950
|
||||
tasks:
|
||||
- name: "mmlu_pro"
|
||||
metrics:
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
# For vllm script, with -t option (tensor parallel size)
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m RedHatAI/Qwen2.5-VL-3B-Instruct-FP8-Dynamic -l 1319 -t 1
|
||||
model_name: "RedHatAI/Qwen2.5-VL-3B-Instruct-FP8-Dynamic"
|
||||
required_gpu_arch:
|
||||
- gfx942
|
||||
- gfx950
|
||||
tasks:
|
||||
- name: "gsm8k"
|
||||
metrics:
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
model_name: "Qwen/Qwen3-235B-A22B-Instruct-2507-FP8"
|
||||
required_gpu_arch:
|
||||
- gfx942
|
||||
- gfx950
|
||||
tasks:
|
||||
- name: "mmlu_pro"
|
||||
metrics:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
Qwen2.5-1.5B-Instruct.yaml
|
||||
Meta-Llama-3.2-1B-Instruct-INT8-compressed-tensors.yaml
|
||||
Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml
|
||||
Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
|
||||
Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml
|
||||
Qwen1.5-MoE-W4A16-compressed-tensors.yaml
|
||||
|
||||
@@ -13,6 +13,7 @@ import os
|
||||
from contextlib import contextmanager
|
||||
|
||||
import lm_eval
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
@@ -89,9 +90,40 @@ def launch_lm_eval(eval_config, tp_size):
|
||||
return results
|
||||
|
||||
|
||||
def _check_rocm_gpu_arch_requirement(eval_config):
|
||||
"""Skip the test if the model requires a ROCm GPU arch not present.
|
||||
|
||||
Model YAML configs can specify::
|
||||
|
||||
required_gpu_arch:
|
||||
- gfx942
|
||||
- gfx950
|
||||
|
||||
The check only applies on ROCm. On other platforms (e.g. CUDA) the
|
||||
field is ignored so that shared config files work for both NVIDIA and
|
||||
AMD CI pipelines.
|
||||
"""
|
||||
required_archs = eval_config.get("required_gpu_arch")
|
||||
if not required_archs:
|
||||
return
|
||||
|
||||
if not current_platform.is_rocm():
|
||||
return
|
||||
|
||||
from vllm.platforms.rocm import _GCN_ARCH # noqa: E402
|
||||
|
||||
if not any(arch in _GCN_ARCH for arch in required_archs):
|
||||
pytest.skip(
|
||||
f"Model requires GPU arch {required_archs}, "
|
||||
f"but detected arch is '{_GCN_ARCH}'"
|
||||
)
|
||||
|
||||
|
||||
def test_lm_eval_correctness_param(config_filename, tp_size):
|
||||
eval_config = yaml.safe_load(config_filename.read_text(encoding="utf-8"))
|
||||
|
||||
_check_rocm_gpu_arch_requirement(eval_config)
|
||||
|
||||
results = launch_lm_eval(eval_config, tp_size)
|
||||
|
||||
rtol = eval_config.get("rtol", DEFAULT_RTOL)
|
||||
|
||||
@@ -19,7 +19,7 @@ has_new_python=$($PYTHON -c "print(1 if __import__('sys').version_info >= (3,12)
|
||||
if [[ "$has_new_python" -eq 0 ]]; then
|
||||
# use new python from docker
|
||||
docker pull python:3-slim
|
||||
PYTHON="docker run --rm -v $(pwd):/app -w /app python:3-slim python3"
|
||||
PYTHON="docker run --rm -u $(id -u):$(id -g) -v $(pwd):/app -w /app python:3-slim python3"
|
||||
fi
|
||||
|
||||
echo "Using python interpreter: $PYTHON"
|
||||
|
||||
@@ -35,23 +35,6 @@ export PYTHONPATH=".."
|
||||
# Helper Functions
|
||||
###############################################################################
|
||||
|
||||
wait_for_clean_gpus() {
|
||||
local timeout=${1:-300}
|
||||
local start=$SECONDS
|
||||
echo "--- Waiting for clean GPU state (timeout: ${timeout}s)"
|
||||
while true; do
|
||||
if grep -q clean /opt/amdgpu/etc/gpu_state; then
|
||||
echo "GPUs state is \"clean\""
|
||||
return
|
||||
fi
|
||||
if (( SECONDS - start >= timeout )); then
|
||||
echo "Error: GPUs did not reach clean state within ${timeout}s" >&2
|
||||
exit 1
|
||||
fi
|
||||
sleep 3
|
||||
done
|
||||
}
|
||||
|
||||
cleanup_docker() {
|
||||
# Get Docker's root directory
|
||||
docker_root=$(docker info -f '{{.DockerRootDir}}')
|
||||
@@ -365,19 +348,12 @@ apply_rocm_test_overrides() {
|
||||
###############################################################################
|
||||
|
||||
# --- GPU initialization ---
|
||||
echo "--- Confirming Clean Initial State"
|
||||
wait_for_clean_gpus
|
||||
|
||||
echo "--- ROCm info"
|
||||
rocminfo
|
||||
|
||||
# --- Docker housekeeping ---
|
||||
cleanup_docker
|
||||
|
||||
echo "--- Resetting GPUs"
|
||||
echo "reset" > /opt/amdgpu/etc/gpu_state
|
||||
wait_for_clean_gpus
|
||||
|
||||
# --- Pull test image ---
|
||||
echo "--- Pulling container"
|
||||
image_name="rocm/vllm-ci:${BUILDKITE_COMMIT}"
|
||||
|
||||
@@ -23,22 +23,22 @@ if [ "$failed_req" -ne 0 ]; then
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "--- DP+TP"
|
||||
vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -dp=2 --max-model-len=4096 &
|
||||
server_pid=$!
|
||||
timeout 600 bash -c "until curl localhost:8000/v1/models > /dev/null 2>&1; do sleep 1; done" || exit 1
|
||||
vllm bench serve \
|
||||
--backend vllm \
|
||||
--dataset-name random \
|
||||
--model meta-llama/Llama-3.2-3B-Instruct \
|
||||
--num-prompts 20 \
|
||||
--result-dir ./test_results \
|
||||
--result-filename dp_pp.json \
|
||||
--save-result \
|
||||
--endpoint /v1/completions
|
||||
kill -s SIGTERM $server_pid; wait $server_pid || true
|
||||
failed_req=$(jq '.failed' ./test_results/dp_pp.json)
|
||||
if [ "$failed_req" -ne 0 ]; then
|
||||
echo "Some requests were failed!"
|
||||
exit 1
|
||||
fi
|
||||
#echo "--- DP+TP"
|
||||
#vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -dp=2 --max-model-len=4096 &
|
||||
#server_pid=$!
|
||||
#timeout 600 bash -c "until curl localhost:8000/v1/models > /dev/null 2>&1; do sleep 1; done" || exit 1
|
||||
#vllm bench serve \
|
||||
# --backend vllm \
|
||||
# --dataset-name random \
|
||||
# --model meta-llama/Llama-3.2-3B-Instruct \
|
||||
# --num-prompts 20 \
|
||||
# --result-dir ./test_results \
|
||||
# --result-filename dp_pp.json \
|
||||
# --save-result \
|
||||
# --endpoint /v1/completions
|
||||
#kill -s SIGTERM $server_pid; wait $server_pid || true
|
||||
#failed_req=$(jq '.failed' ./test_results/dp_pp.json)
|
||||
#if [ "$failed_req" -ne 0 ]; then
|
||||
# echo "Some requests were failed!"
|
||||
# exit 1
|
||||
#fi
|
||||
|
||||
@@ -42,6 +42,7 @@ docker run \
|
||||
python3 examples/basic/offline_inference/generate.py --model superjob/Qwen3-4B-Instruct-2507-GPTQ-Int4 --block-size 64 --enforce-eager --max-model-len 8192
|
||||
python3 examples/basic/offline_inference/generate.py --model ibm-research/PowerMoE-3b --block-size 64 --enforce-eager -tp 2
|
||||
python3 examples/basic/offline_inference/generate.py --model ibm-research/PowerMoE-3b --block-size 64 --enforce-eager -tp 2 --enable-expert-parallel
|
||||
python3 examples/basic/offline_inference/generate.py --model OPEA/Qwen2.5-0.5B-Instruct-int4-sym-inc --block-size 64 --enforce-eager --max-model-len 8192
|
||||
cd tests
|
||||
pytest -v -s v1/core --ignore=v1/core/test_reset_prefix_cache_e2e.py --ignore=v1/core/test_scheduler_e2e.py
|
||||
pytest -v -s v1/engine
|
||||
@@ -49,6 +50,6 @@ docker run \
|
||||
pytest -v -s v1/worker --ignore=v1/worker/test_gpu_model_runner.py --ignore=v1/worker/test_worker_memory_snapshot.py
|
||||
pytest -v -s v1/structured_output
|
||||
pytest -v -s v1/spec_decode --ignore=v1/spec_decode/test_max_len.py --ignore=v1/spec_decode/test_tree_attention.py --ignore=v1/spec_decode/test_speculators_eagle3.py --ignore=v1/spec_decode/test_acceptance_length.py
|
||||
pytest -v -s v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_multi_connector.py --ignore=v1/kv_connector/unit/test_nixl_connector.py --ignore=v1/kv_connector/unit/test_example_connector.py --ignore=v1/kv_connector/unit/test_lmcache_integration.py -k "not (test_register_kv_caches and FLASH_ATTN and True)"
|
||||
pytest -v -s v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_multi_connector.py --ignore=v1/kv_connector/unit/test_example_connector.py --ignore=v1/kv_connector/unit/test_lmcache_integration.py --ignore=v1/kv_connector/unit/test_hf3fs_client.py --ignore=v1/kv_connector/unit/test_hf3fs_connector.py --ignore=v1/kv_connector/unit/test_hf3fs_metadata_server.py
|
||||
pytest -v -s v1/test_serial_utils.py
|
||||
'
|
||||
|
||||
@@ -751,6 +751,7 @@ steps:
|
||||
timeout_in_minutes: 180
|
||||
mirror_hardwares: [amdexperimental, amdproduction, amdgfx90anightly, amdmi250]
|
||||
agent_pool: mi250_1
|
||||
optional: true
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
source_file_dependencies:
|
||||
- csrc/
|
||||
@@ -790,7 +791,7 @@ steps:
|
||||
- tests/kernels/helion/
|
||||
- vllm/platforms/rocm.py
|
||||
commands:
|
||||
- pip install helion
|
||||
- pip install helion==0.3.3
|
||||
- pytest -v -s kernels/helion/
|
||||
|
||||
|
||||
@@ -2035,7 +2036,6 @@ steps:
|
||||
timeout_in_minutes: 38
|
||||
mirror_hardwares: [amdexperimental, amdproduction, amdgfx942nightly, amdmi325]
|
||||
agent_pool: mi325_1
|
||||
optional: true
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
source_file_dependencies:
|
||||
- csrc/
|
||||
@@ -2165,7 +2165,15 @@ steps:
|
||||
- vllm/platforms/rocm.py
|
||||
- tests/quantization
|
||||
commands:
|
||||
- uv pip install --system torchao==0.14.1
|
||||
|
||||
# temporary install here since we need nightly, will move to requirements/test.in
|
||||
# after torchao 0.12 release, and pin a working version of torchao nightly here
|
||||
|
||||
# since torchao nightly is only compatible with torch nightly currently
|
||||
# https://github.com/pytorch/ao/issues/2919, we'll have to skip new torchao tests for now
|
||||
# we can only upgrade after this is resolved
|
||||
# TODO(jerryzh168): resolve the above comment
|
||||
- uv pip install --system torchao==0.17.0
|
||||
- uv pip install --system conch-triton-kernels
|
||||
- VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization/ --ignore quantization/test_blackwell_moe.py
|
||||
|
||||
@@ -2690,6 +2698,24 @@ steps:
|
||||
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt
|
||||
|
||||
|
||||
- label: LM Eval Small Models (MI325) # TBD
|
||||
timeout_in_minutes: 180
|
||||
mirror_hardwares: [amdexperimental, amdproduction, amdgfx942nightly, amdmi325]
|
||||
agent_pool: mi325_1
|
||||
working_dir: "/vllm-workspace/.buildkite/lm-eval-harness"
|
||||
source_file_dependencies:
|
||||
- csrc/
|
||||
- vllm/model_executor/layers/quantization
|
||||
- vllm/model_executor/models/
|
||||
- vllm/model_executor/model_loader/
|
||||
- vllm/v1/attention/backends/
|
||||
- vllm/v1/attention/selector.py
|
||||
- vllm/_aiter_ops.py
|
||||
- vllm/platforms/rocm.py
|
||||
commands:
|
||||
- pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-small-rocm.txt
|
||||
|
||||
|
||||
- label: LM Eval Small Models (B200-MI325) # TBD
|
||||
timeout_in_minutes: 180
|
||||
mirror_hardwares: [amdexperimental, amdproduction, amdgfx942nightly, amdmi325]
|
||||
@@ -2906,10 +2932,10 @@ steps:
|
||||
- bash .buildkite/scripts/scheduled_integration_test/qwen3_next_mtp_async_eplb.sh 0.8 1319 8040
|
||||
|
||||
##### .buildkite/test_areas/compile.yaml #####
|
||||
# Slowly setting up the tests so that it is also easier for the
|
||||
# Slowly setting up the tests so that it is also easier for the
|
||||
# CI team to review and upstream to the pipelinev2.
|
||||
# The following tests are important for vLLM IR Ops refactoring,
|
||||
# which affects fusion passes on ROCm. So we have to
|
||||
# which affects fusion passes on ROCm. So we have to
|
||||
# enable them as as soon as possible.
|
||||
|
||||
## TODO: Enable the test in this group
|
||||
@@ -2988,7 +3014,7 @@ steps:
|
||||
|
||||
## There are no ops on ROCm for these tests.
|
||||
## The test still passes but the logs are not useful.
|
||||
## fused ops just call torch.ops.symm_mem which
|
||||
## fused ops just call torch.ops.symm_mem which
|
||||
## exists in ROCm even though they don't work
|
||||
# - label: AsyncTP Correctness Tests (2xH100-2xMI325)
|
||||
# - label: Fusion E2E TP2 Quick (H100-MI325)
|
||||
@@ -3320,7 +3346,7 @@ steps:
|
||||
- vllm/_aiter_ops.py
|
||||
- vllm/platforms/rocm.py
|
||||
commands:
|
||||
- uv pip install --system torchao==0.14.1
|
||||
- uv pip install --system torchao==0.17.0
|
||||
- uv pip install --system conch-triton-kernels
|
||||
- VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization/ --ignore quantization/test_blackwell_moe.py
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ depends_on:
|
||||
steps:
|
||||
- label: Basic Correctness
|
||||
timeout_in_minutes: 30
|
||||
device: h200_18gb
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/basic_correctness/test_basic_correctness
|
||||
|
||||
@@ -4,6 +4,7 @@ depends_on:
|
||||
steps:
|
||||
- label: Benchmarks CLI Test
|
||||
timeout_in_minutes: 20
|
||||
device: h200_18gb
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/benchmarks/
|
||||
|
||||
@@ -72,6 +72,7 @@ steps:
|
||||
- vllm/v1/attention/backends/flashinfer.py
|
||||
- vllm/compilation/ # TODO(luka) limit to vllm/compilation/passes
|
||||
- tests/compile/passes/test_fusion_attn.py
|
||||
- tests/compile/passes/test_mla_attn_quant_fusion.py
|
||||
- tests/compile/passes/test_silu_mul_quant_fusion.py
|
||||
- tests/compile/passes/distributed/test_fusion_all_reduce.py
|
||||
- tests/compile/fullgraph/test_full_graph.py
|
||||
@@ -79,6 +80,7 @@ steps:
|
||||
# b200 runners are limited, so we limit the tests to the minimum set only supported on Blackwell
|
||||
- nvidia-smi
|
||||
- pytest -v -s tests/compile/passes/test_fusion_attn.py -k FLASHINFER
|
||||
- pytest -v -s tests/compile/passes/test_mla_attn_quant_fusion.py
|
||||
- pytest -v -s tests/compile/passes/test_silu_mul_quant_fusion.py
|
||||
# this runner has 2 GPUs available even though num_devices=2 is not set
|
||||
- pytest -v -s tests/compile/passes/distributed/test_fusion_all_reduce.py
|
||||
|
||||
@@ -4,6 +4,7 @@ depends_on:
|
||||
steps:
|
||||
- label: Platform Tests (CUDA)
|
||||
timeout_in_minutes: 15
|
||||
device: h200_18gb
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/cuda
|
||||
|
||||
@@ -294,3 +294,23 @@ steps:
|
||||
commands:
|
||||
- pytest -v -s distributed/test_pp_cudagraph.py
|
||||
- pytest -v -s distributed/test_pipeline_parallel.py
|
||||
|
||||
- label: RayExecutorV2 (4 GPUs)
|
||||
timeout_in_minutes: 60
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
num_devices: 4
|
||||
source_file_dependencies:
|
||||
- vllm/v1/executor/ray_executor_v2.py
|
||||
- vllm/v1/executor/abstract.py
|
||||
- vllm/v1/executor/multiproc_executor.py
|
||||
- tests/distributed/test_ray_v2_executor.py
|
||||
- tests/distributed/test_ray_v2_executor_e2e.py
|
||||
- tests/distributed/test_pipeline_parallel.py
|
||||
- tests/basic_correctness/test_basic_correctness.py
|
||||
commands:
|
||||
- export VLLM_USE_RAY_V2_EXECUTOR_BACKEND=1
|
||||
- export NCCL_CUMEM_HOST_ENABLE=0
|
||||
- pytest -v -s distributed/test_ray_v2_executor.py
|
||||
- pytest -v -s distributed/test_ray_v2_executor_e2e.py
|
||||
- pytest -v -s distributed/test_pipeline_parallel.py -k "ray"
|
||||
- TARGET_TEST_SUITE=L4 pytest -v -s basic_correctness/test_basic_correctness.py -k "ray"
|
||||
|
||||
@@ -4,6 +4,7 @@ depends_on:
|
||||
steps:
|
||||
- label: Engine
|
||||
timeout_in_minutes: 15
|
||||
device: h200_18gb
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/engine
|
||||
@@ -25,6 +26,7 @@ steps:
|
||||
|
||||
- label: e2e Scheduling (1 GPU)
|
||||
timeout_in_minutes: 30
|
||||
device: h200_18gb
|
||||
source_file_dependencies:
|
||||
- vllm/v1/
|
||||
- tests/v1/e2e/general/
|
||||
|
||||
@@ -61,6 +61,7 @@ steps:
|
||||
|
||||
- label: Entrypoints Integration (API Server openai - Part 3)
|
||||
timeout_in_minutes: 50
|
||||
device: h200_18gb
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
@@ -105,6 +106,7 @@ steps:
|
||||
|
||||
- label: OpenAI API Correctness
|
||||
timeout_in_minutes: 30
|
||||
device: h200_18gb
|
||||
source_file_dependencies:
|
||||
- csrc/
|
||||
- vllm/entrypoints/openai/
|
||||
|
||||
@@ -4,6 +4,7 @@ depends_on:
|
||||
steps:
|
||||
- label: EPLB Algorithm
|
||||
timeout_in_minutes: 15
|
||||
device: h200_18gb
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
source_file_dependencies:
|
||||
- vllm/distributed/eplb
|
||||
|
||||
@@ -2,15 +2,25 @@ group: Kernels
|
||||
depends_on:
|
||||
- image-build
|
||||
steps:
|
||||
- label: vLLM IR Tests
|
||||
timeout_in_minutes: 10
|
||||
device: h200_18gb
|
||||
working_dir: "/vllm-workspace/"
|
||||
source_file_dependencies:
|
||||
- vllm/ir
|
||||
- vllm/kernels
|
||||
commands:
|
||||
- pytest -v -s tests/ir
|
||||
- pytest -v -s tests/kernels/ir
|
||||
|
||||
- label: Kernels Core Operation Test
|
||||
timeout_in_minutes: 75
|
||||
source_file_dependencies:
|
||||
- csrc/
|
||||
- tests/kernels/core
|
||||
- tests/kernels/test_top_k_per_row.py
|
||||
- tests/kernels/test_concat_mla_q.py
|
||||
commands:
|
||||
- pytest -v -s kernels/core kernels/test_top_k_per_row.py kernels/test_concat_mla_q.py
|
||||
- pytest -v -s kernels/core kernels/test_concat_mla_q.py
|
||||
|
||||
- label: Kernels Attention Test %N
|
||||
timeout_in_minutes: 35
|
||||
@@ -19,6 +29,7 @@ steps:
|
||||
- vllm/v1/attention
|
||||
# TODO: remove this dependency (https://github.com/vllm-project/vllm/issues/32267)
|
||||
- vllm/model_executor/layers/attention
|
||||
- vllm/utils/flashinfer.py
|
||||
- tests/kernels/attention
|
||||
commands:
|
||||
- pytest -v -s kernels/attention --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
|
||||
@@ -95,6 +106,7 @@ steps:
|
||||
- vllm/v1/attention/backends/mla/flashinfer_mla.py
|
||||
- vllm/v1/attention/selector.py
|
||||
- vllm/platforms/cuda.py
|
||||
- tests/kernels/test_top_k_per_row.py
|
||||
commands:
|
||||
- nvidia-smi
|
||||
- python3 examples/basic/offline_inference/chat.py
|
||||
@@ -105,6 +117,7 @@ steps:
|
||||
- pytest -v -s tests/kernels/attention/test_flashinfer_trtllm_attention.py
|
||||
- pytest -v -s tests/kernels/attention/test_cutlass_mla_decode.py
|
||||
- pytest -v -s tests/kernels/attention/test_flashinfer_mla_decode.py
|
||||
- pytest -v -s tests/kernels/test_top_k_per_row.py
|
||||
# Quantization
|
||||
- pytest -v -s tests/kernels/quantization/test_cutlass_scaled_mm.py -k 'fp8'
|
||||
- pytest -v -s tests/kernels/quantization/test_nvfp4_quant.py
|
||||
@@ -129,7 +142,7 @@ steps:
|
||||
- vllm/utils/import_utils.py
|
||||
- tests/kernels/helion/
|
||||
commands:
|
||||
- pip install helion
|
||||
- pip install helion==0.3.3
|
||||
- pytest -v -s kernels/helion/
|
||||
|
||||
|
||||
@@ -168,3 +181,21 @@ steps:
|
||||
- pytest -v -s kernels/moe/test_flashinfer_moe.py
|
||||
- pytest -v -s kernels/moe/test_nvfp4_moe.py
|
||||
- pytest -v -s kernels/moe/test_ocp_mx_moe.py
|
||||
|
||||
|
||||
- label: Kernels FusedMoE Layer Test (2 H100s)
|
||||
timeout_in_minutes: 90
|
||||
device: h100
|
||||
num_devices: 2
|
||||
optional: true
|
||||
commands:
|
||||
- pytest -v -s kernels/moe/test_moe_layer.py
|
||||
|
||||
|
||||
- label: Kernels FusedMoE Layer Test (2 B200s)
|
||||
timeout_in_minutes: 90
|
||||
device: b200
|
||||
num_devices: 2
|
||||
optional: true
|
||||
commands:
|
||||
- pytest -v -s kernels/moe/test_moe_layer.py
|
||||
|
||||
@@ -19,6 +19,7 @@ steps:
|
||||
|
||||
- label: V1 Sample + Logits
|
||||
timeout_in_minutes: 30
|
||||
device: h200_18gb
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/v1/sample
|
||||
@@ -86,6 +87,7 @@ steps:
|
||||
|
||||
- label: Regression
|
||||
timeout_in_minutes: 20
|
||||
device: h200_18gb
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/test_regression
|
||||
@@ -174,6 +176,7 @@ steps:
|
||||
- tests/renderers
|
||||
- tests/standalone_tests/lazy_imports.py
|
||||
- tests/tokenizers_
|
||||
- tests/reasoning
|
||||
- tests/tool_parsers
|
||||
- tests/transformers_utils
|
||||
- tests/config
|
||||
@@ -187,6 +190,7 @@ steps:
|
||||
- pytest -v -s -m 'cpu_test' multimodal
|
||||
- pytest -v -s renderers
|
||||
- pytest -v -s tokenizers_
|
||||
- pytest -v -s reasoning --ignore=reasoning/test_seedoss_reasoning_parser.py --ignore=reasoning/test_glm4_moe_reasoning_parser.py --ignore=reasoning/test_gemma4_reasoning_parser.py
|
||||
- pytest -v -s tool_parsers
|
||||
- pytest -v -s transformers_utils
|
||||
- pytest -v -s config
|
||||
|
||||
@@ -78,7 +78,6 @@ steps:
|
||||
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py -k "not ray"
|
||||
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_eagle_dp.py
|
||||
|
||||
# These require fix https://github.com/vllm-project/vllm/pull/36280
|
||||
- label: Model Runner V2 Pipeline Parallelism (4 GPUs)
|
||||
timeout_in_minutes: 60
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
@@ -101,11 +100,13 @@ steps:
|
||||
- vllm/v1/worker/gpu/
|
||||
- vllm/v1/worker/gpu_worker.py
|
||||
- tests/v1/spec_decode/test_max_len.py
|
||||
- tests/v1/spec_decode/test_probabilistic_rejection_sampler_utils.py
|
||||
- tests/v1/spec_decode/test_synthetic_rejection_sampler_utils.py
|
||||
- tests/v1/e2e/spec_decode/test_spec_decode.py
|
||||
commands:
|
||||
- set -x
|
||||
- export VLLM_USE_V2_MODEL_RUNNER=1
|
||||
- pytest -v -s v1/spec_decode/test_max_len.py -k "eagle or mtp"
|
||||
- pytest -v -s v1/spec_decode/test_probabilistic_rejection_sampler_utils.py
|
||||
- pytest -v -s v1/spec_decode/test_synthetic_rejection_sampler_utils.py
|
||||
- pytest -v -s v1/e2e/spec_decode/test_spec_decode.py -k "eagle or mtp"
|
||||
|
||||
@@ -4,6 +4,7 @@ depends_on:
|
||||
steps:
|
||||
- label: Basic Models Tests (Initialization)
|
||||
timeout_in_minutes: 45
|
||||
device: h200_18gb
|
||||
torch_nightly: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
|
||||
@@ -18,5 +18,6 @@ steps:
|
||||
# Avoid importing model tests that cause CUDA reinitialization error
|
||||
- pytest models/test_transformers.py -v -s -m 'distributed(num_gpus=2)'
|
||||
- pytest models/language -v -s -m 'distributed(num_gpus=2)'
|
||||
- pytest models/multimodal -v -s -m 'distributed(num_gpus=2)' --ignore models/multimodal/generation/test_whisper.py
|
||||
- pytest models/multimodal/generation/test_phi4siglip.py -v -s -m 'distributed(num_gpus=2)'
|
||||
- pytest models/multimodal -v -s -m 'distributed(num_gpus=2)' --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/generation/test_phi4siglip.py
|
||||
- VLLM_WORKER_MULTIPROC_METHOD=spawn pytest models/multimodal/generation/test_whisper.py -v -s -m 'distributed(num_gpus=2)'
|
||||
|
||||
@@ -38,7 +38,7 @@ steps:
|
||||
# Install fast path packages for testing against transformers
|
||||
# Note: also needed to run plamo2 model in vLLM
|
||||
- uv pip install --system --no-build-isolation 'git+https://github.com/state-spaces/mamba@v2.3.0'
|
||||
- uv pip install --system --no-build-isolation 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.2'
|
||||
- uv pip install --system --no-build-isolation 'git+https://github.com/Dao-AILab/causal-conv1d@v1.6.0'
|
||||
# Shard hybrid language model tests
|
||||
- pytest -v -s models/language/generation -m hybrid_model --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --shard-id=$$BUILDKITE_PARALLEL_JOB
|
||||
parallelism: 2
|
||||
@@ -53,7 +53,7 @@ steps:
|
||||
# Install fast path packages for testing against transformers
|
||||
# Note: also needed to run plamo2 model in vLLM
|
||||
- uv pip install --system --no-build-isolation 'git+https://github.com/state-spaces/mamba@v2.3.0'
|
||||
- uv pip install --system --no-build-isolation 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.2'
|
||||
- uv pip install --system --no-build-isolation 'git+https://github.com/Dao-AILab/causal-conv1d@v1.6.0'
|
||||
- pytest -v -s models/language/generation -m '(not core_model) and (not hybrid_model)'
|
||||
mirror:
|
||||
amd:
|
||||
@@ -67,6 +67,7 @@ steps:
|
||||
|
||||
- label: Language Models Test (PPL)
|
||||
timeout_in_minutes: 110
|
||||
device: h200_18gb
|
||||
optional: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
@@ -90,6 +91,7 @@ steps:
|
||||
|
||||
- label: Language Models Test (MTEB)
|
||||
timeout_in_minutes: 110
|
||||
device: h200_18gb
|
||||
optional: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
|
||||
@@ -4,6 +4,7 @@ depends_on:
|
||||
steps:
|
||||
- label: "Multi-Modal Models (Standard) 1: qwen2"
|
||||
timeout_in_minutes: 45
|
||||
device: h200_18gb
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/models/multimodal
|
||||
@@ -19,6 +20,7 @@ steps:
|
||||
|
||||
- label: "Multi-Modal Models (Standard) 2: qwen3 + gemma"
|
||||
timeout_in_minutes: 45
|
||||
device: h200_18gb
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/models/multimodal
|
||||
@@ -54,7 +56,8 @@ steps:
|
||||
- tests/models/multimodal
|
||||
commands:
|
||||
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
|
||||
- pytest -v -s models/multimodal -m core_model --ignore models/multimodal/generation/test_common.py --ignore models/multimodal/generation/test_ultravox.py --ignore models/multimodal/generation/test_qwen2_5_vl.py --ignore models/multimodal/generation/test_qwen2_vl.py --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/processing
|
||||
- pytest -v -s models/multimodal -m core_model --ignore models/multimodal/generation/test_common.py --ignore models/multimodal/generation/test_ultravox.py --ignore models/multimodal/generation/test_qwen2_5_vl.py --ignore models/multimodal/generation/test_qwen2_vl.py --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/generation/test_memory_leak.py --ignore models/multimodal/processing
|
||||
- pytest models/multimodal/generation/test_memory_leak.py -m core_model
|
||||
- cd .. && VLLM_WORKER_MULTIPROC_METHOD=spawn pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work
|
||||
mirror:
|
||||
amd:
|
||||
@@ -77,6 +80,7 @@ steps:
|
||||
|
||||
- label: Multi-Modal Processor # 44min
|
||||
timeout_in_minutes: 60
|
||||
device: h200_18gb
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/models/multimodal
|
||||
@@ -131,6 +135,7 @@ steps:
|
||||
|
||||
- label: Multi-Modal Models (Extended Pooling)
|
||||
optional: true
|
||||
device: h200_18gb
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/models/multimodal/pooling
|
||||
|
||||
@@ -49,6 +49,7 @@ steps:
|
||||
|
||||
- label: PyTorch Fullgraph
|
||||
timeout_in_minutes: 30
|
||||
device: h200_18gb
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/compile
|
||||
@@ -60,6 +61,7 @@ steps:
|
||||
# if this test fails, it means the nightly torch version is not compatible with some
|
||||
# of the dependencies. Please check the error message and add the package to whitelist
|
||||
# in /vllm/tools/pre_commit/generate_nightly_torch_test.py
|
||||
device: h200_18gb
|
||||
soft_fail: true
|
||||
source_file_dependencies:
|
||||
- requirements/nightly_torch_test.txt
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
group: Quantization
|
||||
depends_on:
|
||||
depends_on:
|
||||
- image-build
|
||||
steps:
|
||||
- label: Quantization
|
||||
@@ -16,7 +16,7 @@ steps:
|
||||
# https://github.com/pytorch/ao/issues/2919, we'll have to skip new torchao tests for now
|
||||
# we can only upgrade after this is resolved
|
||||
# TODO(jerryzh168): resolve the above comment
|
||||
- uv pip install --system torchao==0.14.1 --index-url https://download.pytorch.org/whl/cu129
|
||||
- uv pip install --system torchao==0.17.0 --index-url https://download.pytorch.org/whl/cu130
|
||||
- uv pip install --system conch-triton-kernels
|
||||
- VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization/ --ignore quantization/test_blackwell_moe.py
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ steps:
|
||||
# If this fails, it means the PR introduces a dependency that
|
||||
# conflicts with Ray's dependency constraints.
|
||||
# See https://github.com/vllm-project/vllm/issues/33599
|
||||
device: h200_18gb
|
||||
soft_fail: true
|
||||
timeout_in_minutes: 10
|
||||
source_file_dependencies:
|
||||
|
||||
@@ -4,6 +4,18 @@ depends_on:
|
||||
steps:
|
||||
- label: Spec Decode Eagle
|
||||
timeout_in_minutes: 30
|
||||
device: h200_18gb
|
||||
source_file_dependencies:
|
||||
- vllm/v1/spec_decode/
|
||||
- vllm/v1/worker/gpu/spec_decode/
|
||||
- tests/v1/e2e/spec_decode/
|
||||
commands:
|
||||
- pytest -v -s v1/e2e/spec_decode -k "eagle_correctness"
|
||||
|
||||
- label: Spec Decode Eagle Nightly B200
|
||||
timeout_in_minutes: 30
|
||||
device: b200
|
||||
optional: true
|
||||
source_file_dependencies:
|
||||
- vllm/v1/spec_decode/
|
||||
- vllm/v1/worker/gpu/spec_decode/
|
||||
@@ -13,6 +25,7 @@ steps:
|
||||
|
||||
- label: Spec Decode Speculators + MTP
|
||||
timeout_in_minutes: 30
|
||||
device: h200_18gb
|
||||
source_file_dependencies:
|
||||
- vllm/v1/spec_decode/
|
||||
- vllm/v1/worker/gpu/spec_decode/
|
||||
@@ -21,8 +34,21 @@ steps:
|
||||
commands:
|
||||
- pytest -v -s v1/e2e/spec_decode -k "speculators or mtp_correctness"
|
||||
|
||||
- label: Spec Decode Speculators + MTP Nightly B200
|
||||
timeout_in_minutes: 30
|
||||
device: b200
|
||||
optional: true
|
||||
source_file_dependencies:
|
||||
- vllm/v1/spec_decode/
|
||||
- vllm/v1/worker/gpu/spec_decode/
|
||||
- vllm/transformers_utils/configs/speculators/
|
||||
- tests/v1/e2e/spec_decode/
|
||||
commands:
|
||||
- pytest -v -s v1/e2e/spec_decode -k "speculators or mtp_correctness"
|
||||
|
||||
- label: Spec Decode Ngram + Suffix
|
||||
timeout_in_minutes: 30
|
||||
device: h200_18gb
|
||||
source_file_dependencies:
|
||||
- vllm/v1/spec_decode/
|
||||
- vllm/v1/worker/gpu/spec_decode/
|
||||
@@ -32,6 +58,18 @@ steps:
|
||||
|
||||
- label: Spec Decode Draft Model
|
||||
timeout_in_minutes: 30
|
||||
device: h200_18gb
|
||||
source_file_dependencies:
|
||||
- vllm/v1/spec_decode/
|
||||
- vllm/v1/worker/gpu/spec_decode/
|
||||
- tests/v1/e2e/spec_decode/
|
||||
commands:
|
||||
- pytest -v -s v1/e2e/spec_decode -k "draft_model or no_sync or batch_inference"
|
||||
|
||||
- label: Spec Decode Draft Model Nightly B200
|
||||
timeout_in_minutes: 30
|
||||
device: b200
|
||||
optional: true
|
||||
source_file_dependencies:
|
||||
- vllm/v1/spec_decode/
|
||||
- vllm/v1/worker/gpu/spec_decode/
|
||||
|
||||
4
.github/CODEOWNERS
vendored
4
.github/CODEOWNERS
vendored
@@ -13,6 +13,9 @@
|
||||
/vllm/model_executor/layers/rotary_embedding.py @vadiklyutiy
|
||||
/vllm/model_executor/model_loader @22quinn
|
||||
/vllm/model_executor/layers/batch_invariant.py @yewentao256
|
||||
/vllm/ir @ProExpertProg
|
||||
/vllm/kernels/ @ProExpertProg @tjtanaa
|
||||
/vllm/kernels/helion @ProExpertProg @zou3519
|
||||
/vllm/multimodal @DarkLight1337 @ywang96 @NickLucche @tjtanaa
|
||||
/vllm/vllm_flash_attn @LucasWilkinson @MatthewBonanni
|
||||
CMakeLists.txt @tlrmchlsmth @LucasWilkinson
|
||||
@@ -74,6 +77,7 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
|
||||
/tests/entrypoints @DarkLight1337 @robertgshaw2-redhat @aarnphm @NickLucche
|
||||
/tests/evals @mgoin @vadiklyutiy
|
||||
/tests/kernels @mgoin @tlrmchlsmth @WoosukKwon @yewentao256
|
||||
/tests/kernels/ir @ProExpertProg @tjtanaa
|
||||
/tests/models @DarkLight1337 @ywang96
|
||||
/tests/multimodal @DarkLight1337 @ywang96 @NickLucche
|
||||
/tests/quantization @mgoin @robertgshaw2-redhat @yewentao256 @pavanimajety
|
||||
|
||||
11
.github/mergify.yml
vendored
11
.github/mergify.yml
vendored
@@ -18,7 +18,7 @@ pull_request_rules:
|
||||
- name: comment-pre-commit-failure
|
||||
description: Comment on PR when pre-commit check fails
|
||||
conditions:
|
||||
- status-failure=pre-commit
|
||||
- check-failure=pre-commit
|
||||
- -closed
|
||||
- -draft
|
||||
actions:
|
||||
@@ -51,7 +51,7 @@ pull_request_rules:
|
||||
- name: comment-dco-failure
|
||||
description: Comment on PR when DCO check fails
|
||||
conditions:
|
||||
- status-failure=dco
|
||||
- check-failure=dco
|
||||
- -closed
|
||||
- -draft
|
||||
actions:
|
||||
@@ -378,17 +378,18 @@ pull_request_rules:
|
||||
add:
|
||||
- tool-calling
|
||||
|
||||
- name: auto-rebase if approved, ready, and 40 commits behind main
|
||||
- name: auto-rebase to keep merge candidate within 1 day behind main
|
||||
conditions:
|
||||
- base = main
|
||||
- label=ready
|
||||
- "#approved-reviews-by >= 1"
|
||||
- "#commits-behind >= 40"
|
||||
- "#commits-behind >= 50"
|
||||
- "#check-failure = 0"
|
||||
- -closed
|
||||
- -draft
|
||||
- -conflict
|
||||
actions:
|
||||
rebase: {}
|
||||
update: {}
|
||||
|
||||
- name: ping author on conflicts and add 'needs-rebase' label
|
||||
conditions:
|
||||
|
||||
7
.github/workflows/pre-commit.yml
vendored
7
.github/workflows/pre-commit.yml
vendored
@@ -28,6 +28,7 @@ jobs:
|
||||
});
|
||||
|
||||
const hasReadyLabel = pr.labels.some(l => l.name === 'ready');
|
||||
const hasVerifiedLabel = pr.labels.some(l => l.name === 'verified');
|
||||
|
||||
const { data: mergedPRs } = await github.rest.search.issuesAndPullRequests({
|
||||
q: `repo:${context.repo.owner}/${context.repo.repo} is:pr is:merged author:${pr.user.login}`,
|
||||
@@ -35,10 +36,10 @@ jobs:
|
||||
});
|
||||
const mergedCount = mergedPRs.total_count;
|
||||
|
||||
if (hasReadyLabel || mergedCount >= 4) {
|
||||
core.info(`Check passed: ready label=${hasReadyLabel}, 4+ merged PRs=${mergedCount >= 4}`);
|
||||
if (hasReadyLabel || hasVerifiedLabel || mergedCount >= 4) {
|
||||
core.info(`Check passed: verified label=${hasVerifiedLabel}, ready label=${hasReadyLabel}, 4+ merged PRs=${mergedCount >= 4}`);
|
||||
} else {
|
||||
core.setFailed(`PR must have the 'ready' label or the author must have at least 4 merged PRs (found ${mergedCount}).`);
|
||||
core.setFailed(`PR must have the 'verified' or 'ready' (which also triggers tests) label or the author must have at least 4 merged PRs (found ${mergedCount}).`);
|
||||
}
|
||||
|
||||
pre-commit:
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -12,6 +12,9 @@ vllm/third_party/triton_kernels/*
|
||||
# FlashMLA interface copied from source
|
||||
vllm/third_party/flashmla/flash_mla_interface.py
|
||||
|
||||
# DeepGEMM vendored package built from source
|
||||
vllm/third_party/deep_gemm/
|
||||
|
||||
# triton jit
|
||||
.triton
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ repos:
|
||||
rev: 0.11.1
|
||||
hooks:
|
||||
- id: pip-compile
|
||||
args: [requirements/test.in, -o, requirements/test.txt, --index-strategy, unsafe-best-match, --torch-backend, cu129, --python-platform, x86_64-manylinux_2_28, --python-version, "3.12"]
|
||||
args: [requirements/test.in, -c, requirements/common.txt, -o, requirements/test.txt, --index-strategy, unsafe-best-match, --torch-backend, cu130, --python-platform, x86_64-manylinux_2_28, --python-version, "3.12"]
|
||||
files: ^requirements/test\.(in|txt)$
|
||||
- id: pip-compile
|
||||
alias: pip-compile-rocm
|
||||
@@ -59,21 +59,54 @@ repos:
|
||||
--no-emit-package, cuda-pathfinder,
|
||||
--no-emit-package, cuda-toolkit,
|
||||
--no-emit-package, cupy-cuda12x,
|
||||
# nvidia packages (unsuffixed / unified naming)
|
||||
--no-emit-package, nvidia-cublas,
|
||||
--no-emit-package, nvidia-cuda-cupti,
|
||||
--no-emit-package, nvidia-cuda-nvrtc,
|
||||
--no-emit-package, nvidia-cuda-runtime,
|
||||
--no-emit-package, nvidia-cudnn-cu13,
|
||||
--no-emit-package, nvidia-cudnn,
|
||||
--no-emit-package, nvidia-cufft,
|
||||
--no-emit-package, nvidia-cufile,
|
||||
--no-emit-package, nvidia-curand,
|
||||
--no-emit-package, nvidia-cusolver,
|
||||
--no-emit-package, nvidia-cusparse,
|
||||
--no-emit-package, nvidia-cusparselt,
|
||||
--no-emit-package, nvidia-nccl,
|
||||
--no-emit-package, nvidia-nvjitlink,
|
||||
--no-emit-package, nvidia-nvshmem,
|
||||
--no-emit-package, nvidia-nvtx,
|
||||
# nvidia cu12 packages
|
||||
--no-emit-package, nvidia-cublas-cu12,
|
||||
--no-emit-package, nvidia-cuda-cupti-cu12,
|
||||
--no-emit-package, nvidia-cuda-nvrtc-cu12,
|
||||
--no-emit-package, nvidia-cuda-runtime-cu12,
|
||||
--no-emit-package, nvidia-cudnn-cu12,
|
||||
--no-emit-package, nvidia-cufft-cu12,
|
||||
--no-emit-package, nvidia-cufile-cu12,
|
||||
--no-emit-package, nvidia-curand-cu12,
|
||||
--no-emit-package, nvidia-cusolver-cu12,
|
||||
--no-emit-package, nvidia-cusparse-cu12,
|
||||
--no-emit-package, nvidia-cusparselt-cu12,
|
||||
--no-emit-package, nvidia-nccl-cu12,
|
||||
--no-emit-package, nvidia-nvjitlink-cu12,
|
||||
--no-emit-package, nvidia-nvshmem-cu12,
|
||||
--no-emit-package, nvidia-nvtx-cu12,
|
||||
# nvidia cu13 packages
|
||||
--no-emit-package, nvidia-cublas-cu13,
|
||||
--no-emit-package, nvidia-cuda-cupti-cu13,
|
||||
--no-emit-package, nvidia-cuda-nvrtc-cu13,
|
||||
--no-emit-package, nvidia-cuda-runtime-cu13,
|
||||
--no-emit-package, nvidia-cudnn-cu13,
|
||||
--no-emit-package, nvidia-cufft-cu13,
|
||||
--no-emit-package, nvidia-cufile-cu13,
|
||||
--no-emit-package, nvidia-curand-cu13,
|
||||
--no-emit-package, nvidia-cusolver-cu13,
|
||||
--no-emit-package, nvidia-cusparse-cu13,
|
||||
--no-emit-package, nvidia-cusparselt-cu13,
|
||||
--no-emit-package, nvidia-nccl-cu13,
|
||||
--no-emit-package, nvidia-nvjitlink,
|
||||
--no-emit-package, nvidia-nvjitlink-cu13,
|
||||
--no-emit-package, nvidia-nvshmem-cu13,
|
||||
--no-emit-package, nvidia-nvtx,
|
||||
--no-emit-package, nvidia-nvtx-cu13,
|
||||
]
|
||||
files: ^requirements/rocm-test\.(in|txt)$
|
||||
- repo: local
|
||||
|
||||
189
CMakeLists.txt
189
CMakeLists.txt
@@ -56,8 +56,8 @@ endif()
|
||||
# requirements.txt files and should be kept consistent. The ROCm torch
|
||||
# versions are derived from docker/Dockerfile.rocm
|
||||
#
|
||||
set(TORCH_SUPPORTED_VERSION_CUDA "2.10.0")
|
||||
set(TORCH_SUPPORTED_VERSION_ROCM "2.10.0")
|
||||
set(TORCH_SUPPORTED_VERSION_CUDA "2.11.0")
|
||||
set(TORCH_SUPPORTED_VERSION_ROCM "2.11.0")
|
||||
|
||||
#
|
||||
# Try to find python package with an executable that exactly matches
|
||||
@@ -225,8 +225,8 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
|
||||
# Certain HIP functions are marked as [[nodiscard]], yet vllm ignores the result which generates
|
||||
# a lot of warnings that always mask real issues. Suppressing until this is properly addressed.
|
||||
#
|
||||
set(CMAKE_${VLLM_GPU_LANG}_FLAGS "${CMAKE_${VLLM_GPU_LANG}_FLAGS} -Wno-unused-result")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-result")
|
||||
set(CMAKE_${VLLM_GPU_LANG}_FLAGS "${CMAKE_${VLLM_GPU_LANG}_FLAGS} -Wno-unused-result -Wno-unused-value")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-result -Wno-unused-value")
|
||||
endif()
|
||||
|
||||
#
|
||||
@@ -299,6 +299,7 @@ set(VLLM_EXT_SRC
|
||||
"csrc/quantization/w8a8/int8/scaled_quant.cu"
|
||||
"csrc/quantization/w8a8/fp8/common.cu"
|
||||
"csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
|
||||
"csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu"
|
||||
"csrc/quantization/gguf/gguf_kernel.cu"
|
||||
"csrc/quantization/activation_kernels.cu"
|
||||
"csrc/cuda_utils_kernels.cu"
|
||||
@@ -340,8 +341,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
|
||||
list(APPEND VLLM_EXT_SRC
|
||||
"csrc/quantization/awq/gemm_kernels.cu"
|
||||
"csrc/quantization/fp4/nvfp4_quant_entry.cu"
|
||||
"csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
|
||||
"csrc/cutlass_extensions/common.cpp")
|
||||
|
||||
set_gencode_flags_for_srcs(
|
||||
@@ -489,59 +488,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
" in CUDA target architectures")
|
||||
endif()
|
||||
|
||||
# The nvfp4_scaled_mm_sm120 kernels for Blackwell SM12x require
|
||||
# CUDA 12.8 or later
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||
cuda_archs_loose_intersection(FP4_ARCHS "12.0f" "${CUDA_ARCHS}")
|
||||
else()
|
||||
cuda_archs_loose_intersection(FP4_ARCHS "12.0a;12.1a" "${CUDA_ARCHS}")
|
||||
endif()
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/quantization/fp4/nvfp4_quant_kernels.cu"
|
||||
"csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
|
||||
"csrc/quantization/fp4/nvfp4_experts_quant.cu"
|
||||
"csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu"
|
||||
"csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${FP4_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM120=1")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM120=1")
|
||||
message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
|
||||
else()
|
||||
message(STATUS "Not building NVFP4 as no compatible archs were found.")
|
||||
# clear FP4_ARCHS
|
||||
set(FP4_ARCHS)
|
||||
endif()
|
||||
|
||||
# FP4 Archs and flags
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||
cuda_archs_loose_intersection(FP4_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
|
||||
else()
|
||||
cuda_archs_loose_intersection(FP4_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
|
||||
endif()
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/quantization/fp4/nvfp4_quant_kernels.cu"
|
||||
"csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
|
||||
"csrc/quantization/fp4/nvfp4_experts_quant.cu"
|
||||
"csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu"
|
||||
"csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${FP4_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM100=1")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
|
||||
message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
|
||||
else()
|
||||
message(STATUS "Not building NVFP4 as no compatible archs were found.")
|
||||
# clear FP4_ARCHS
|
||||
set(FP4_ARCHS)
|
||||
endif()
|
||||
|
||||
# CUTLASS MLA Archs and flags
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||
cuda_archs_loose_intersection(MLA_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
|
||||
@@ -681,34 +627,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# Only build W4A8 kernels if we are building for something compatible with sm90a
|
||||
cuda_archs_loose_intersection(W4A8_ARCHS "9.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND W4A8_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu"
|
||||
"csrc/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu"
|
||||
"csrc/quantization/cutlass_w4a8/w4a8_utils.cu"
|
||||
)
|
||||
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${W4A8_ARCHS}")
|
||||
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
|
||||
message(STATUS "Building W4A8 kernels for archs: ${W4A8_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0
|
||||
AND W4A8_ARCHS)
|
||||
message(STATUS "Not building W4A8 kernels as CUDA Compiler version is "
|
||||
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
|
||||
"later if you intend on running w4a16 quantized models on "
|
||||
"Hopper.")
|
||||
else()
|
||||
message(STATUS "Not building W4A8 kernels as no compatible archs "
|
||||
"found in CUDA target architectures")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# Hadacore kernels
|
||||
cuda_archs_loose_intersection(HADACORE_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}")
|
||||
@@ -760,7 +678,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
set(VLLM_STABLE_EXT_SRC
|
||||
"csrc/libtorch_stable/torch_bindings.cpp"
|
||||
"csrc/cutlass_extensions/common.cpp"
|
||||
"csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu")
|
||||
"csrc/cuda_utils_kernels.cu"
|
||||
"csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu"
|
||||
"csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu"
|
||||
"csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu")
|
||||
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
list(APPEND VLLM_STABLE_EXT_SRC
|
||||
@@ -978,6 +899,96 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
#
|
||||
# FP4/NVFP4 kernels (moved from _C to _C_stable_libtorch)
|
||||
#
|
||||
|
||||
# The nvfp4_scaled_mm_sm120 kernels for Blackwell SM12x require
|
||||
# CUDA 12.8 or later
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||
cuda_archs_loose_intersection(FP4_ARCHS "12.0f" "${CUDA_ARCHS}")
|
||||
else()
|
||||
cuda_archs_loose_intersection(FP4_ARCHS "12.0a;12.1a" "${CUDA_ARCHS}")
|
||||
endif()
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/libtorch_stable/quantization/fp4/nvfp4_quant_kernels.cu"
|
||||
"csrc/libtorch_stable/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
|
||||
"csrc/libtorch_stable/quantization/fp4/nvfp4_experts_quant.cu"
|
||||
"csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu"
|
||||
"csrc/libtorch_stable/quantization/fp4/nvfp4_blockwise_moe_kernel.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${FP4_ARCHS}")
|
||||
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM120=1")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM120=1")
|
||||
message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
|
||||
else()
|
||||
message(STATUS "Not building NVFP4 as no compatible archs were found.")
|
||||
# clear FP4_ARCHS
|
||||
set(FP4_ARCHS)
|
||||
endif()
|
||||
|
||||
# FP4 Archs and flags
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||
cuda_archs_loose_intersection(FP4_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
|
||||
else()
|
||||
cuda_archs_loose_intersection(FP4_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
|
||||
endif()
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/libtorch_stable/quantization/fp4/nvfp4_quant_kernels.cu"
|
||||
"csrc/libtorch_stable/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
|
||||
"csrc/libtorch_stable/quantization/fp4/nvfp4_experts_quant.cu"
|
||||
"csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_kernels.cu"
|
||||
"csrc/libtorch_stable/quantization/fp4/nvfp4_blockwise_moe_kernel.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${FP4_ARCHS}")
|
||||
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM100=1")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
|
||||
message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
|
||||
else()
|
||||
message(STATUS "Not building NVFP4 as no compatible archs were found.")
|
||||
# clear FP4_ARCHS
|
||||
set(FP4_ARCHS)
|
||||
endif()
|
||||
|
||||
#
|
||||
# W4A8 kernels (moved from _C to _C_stable_libtorch)
|
||||
#
|
||||
|
||||
# Only build W4A8 kernels if we are building for something compatible with sm90a
|
||||
cuda_archs_loose_intersection(W4A8_ARCHS "9.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND W4A8_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_mm_entry.cu"
|
||||
"csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu"
|
||||
"csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_utils.cu"
|
||||
)
|
||||
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${W4A8_ARCHS}")
|
||||
|
||||
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
|
||||
|
||||
message(STATUS "Building W4A8 kernels for archs: ${W4A8_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0
|
||||
AND W4A8_ARCHS)
|
||||
message(STATUS "Not building W4A8 kernels as CUDA Compiler version is "
|
||||
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
|
||||
"later if you intend on running w4a16 quantized models on "
|
||||
"Hopper.")
|
||||
else()
|
||||
message(STATUS "Not building W4A8 kernels as no compatible archs "
|
||||
"found in CUDA target architectures")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
message(STATUS "Enabling C_stable extension.")
|
||||
define_extension_target(
|
||||
_C_stable_libtorch
|
||||
@@ -1019,7 +1030,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
list(APPEND VLLM_MOE_EXT_SRC
|
||||
"csrc/moe/moe_wna16.cu"
|
||||
"csrc/moe/grouped_topk_kernels.cu"
|
||||
"csrc/moe/gpt_oss_router_gemm.cu"
|
||||
"csrc/moe/router_gemm.cu")
|
||||
endif()
|
||||
|
||||
@@ -1212,6 +1222,7 @@ endif()
|
||||
|
||||
# For CUDA we also build and ship some external projects.
|
||||
if (VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
include(cmake/external_projects/deepgemm.cmake)
|
||||
include(cmake/external_projects/flashmla.cmake)
|
||||
include(cmake/external_projects/qutlass.cmake)
|
||||
|
||||
|
||||
45
README.md
45
README.md
@@ -23,47 +23,54 @@ For events, please visit [vllm.ai/events](https://vllm.ai/events) to join us.
|
||||
|
||||
vLLM is a fast and easy-to-use library for LLM inference and serving.
|
||||
|
||||
Originally developed in the [Sky Computing Lab](https://sky.cs.berkeley.edu) at UC Berkeley, vLLM has evolved into a community-driven project with contributions from both academia and industry.
|
||||
Originally developed in the [Sky Computing Lab](https://sky.cs.berkeley.edu) at UC Berkeley, vLLM has grown into one of the most active open-source AI projects built and maintained by a diverse community of many dozens of academic institutions and companies from over 2000 contributors.
|
||||
|
||||
vLLM is fast with:
|
||||
|
||||
- State-of-the-art serving throughput
|
||||
- Efficient management of attention key and value memory with [**PagedAttention**](https://blog.vllm.ai/2023/06/20/vllm.html)
|
||||
- Continuous batching of incoming requests
|
||||
- Fast model execution with CUDA/HIP graph
|
||||
- Quantizations: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [AutoRound](https://arxiv.org/abs/2309.05516), INT4, INT8, and FP8
|
||||
- Optimized CUDA kernels, including integration with FlashAttention and FlashInfer
|
||||
- Speculative decoding
|
||||
- Chunked prefill
|
||||
- Continuous batching of incoming requests, chunked prefill, prefix caching
|
||||
- Fast and flexible model execution with piecewise and full CUDA/HIP graphs
|
||||
- Quantization: FP8, MXFP8/MXFP4, NVFP4, INT8, INT4, GPTQ/AWQ, GGUF, compressed-tensors, ModelOpt, TorchAO, and [more](https://docs.vllm.ai/en/latest/features/quantization/index.html)
|
||||
- Optimized attention kernels including FlashAttention, FlashInfer, TRTLLM-GEN, FlashMLA, and Triton
|
||||
- Optimized GEMM/MoE kernels for various precisions using CUTLASS, TRTLLM-GEN, CuTeDSL
|
||||
- Speculative decoding including n-gram, suffix, EAGLE, DFlash
|
||||
- Automatic kernel generation and graph-level transformations using torch.compile
|
||||
- Disaggregated prefill, decode, and encode
|
||||
|
||||
vLLM is flexible and easy to use with:
|
||||
|
||||
- Seamless integration with popular Hugging Face models
|
||||
- High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more
|
||||
- Tensor, pipeline, data and expert parallelism support for distributed inference
|
||||
- Tensor, pipeline, data, expert, and context parallelism for distributed inference
|
||||
- Streaming outputs
|
||||
- OpenAI-compatible API server
|
||||
- Support for NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, Arm CPUs, and TPU. Additionally, support for diverse hardware plugins such as Intel Gaudi, IBM Spyre and Huawei Ascend.
|
||||
- Prefix caching support
|
||||
- Multi-LoRA support
|
||||
- Generation of structured outputs using xgrammar or guidance
|
||||
- Tool calling and reasoning parsers
|
||||
- OpenAI-compatible API server, plus Anthropic Messages API and gRPC support
|
||||
- Efficient multi-LoRA support for dense and MoE layers
|
||||
- Support for NVIDIA GPUs, AMD GPUs, and x86/ARM/PowerPC CPUs. Additionally, diverse hardware plugins such as Google TPUs, Intel Gaudi, IBM Spyre, Huawei Ascend, Rebellions NPU, Apple Silicon, MetaX GPU, and more.
|
||||
|
||||
vLLM seamlessly supports most popular open-source models on HuggingFace, including:
|
||||
vLLM seamlessly supports 200+ model architectures on HuggingFace, including:
|
||||
|
||||
- Transformer-like LLMs (e.g., Llama)
|
||||
- Mixture-of-Expert LLMs (e.g., Mixtral, Deepseek-V2 and V3)
|
||||
- Embedding Models (e.g., E5-Mistral)
|
||||
- Multi-modal LLMs (e.g., LLaVA)
|
||||
- Decoder-only LLMs (e.g., Llama, Qwen, Gemma)
|
||||
- Mixture-of-Expert LLMs (e.g., Mixtral, DeepSeek-V3, Qwen-MoE, GPT-OSS)
|
||||
- Hybrid attention and state-space models (e.g., Mamba, Qwen3.5)
|
||||
- Multi-modal models (e.g., LLaVA, Qwen-VL, Pixtral)
|
||||
- Embedding and retrieval models (e.g., E5-Mistral, GTE, ColBERT)
|
||||
- Reward and classification models (e.g., Qwen-Math)
|
||||
|
||||
Find the full list of supported models [here](https://docs.vllm.ai/en/latest/models/supported_models.html).
|
||||
|
||||
## Getting Started
|
||||
|
||||
Install vLLM with `pip` or [from source](https://docs.vllm.ai/en/latest/getting_started/installation/gpu/index.html#build-wheel-from-source):
|
||||
Install vLLM with [`uv`](https://docs.astral.sh/uv/) (recommended) or `pip`:
|
||||
|
||||
```bash
|
||||
pip install vllm
|
||||
uv pip install vllm
|
||||
```
|
||||
|
||||
Or [build from source](https://docs.vllm.ai/en/latest/getting_started/installation/gpu/index.html#build-wheel-from-source) for development.
|
||||
|
||||
Visit our [documentation](https://docs.vllm.ai/en/latest/) to learn more.
|
||||
|
||||
- [Installation](https://docs.vllm.ai/en/latest/getting_started/installation.html)
|
||||
|
||||
264
benchmarks/fused_kernels/merge_attn_states_benchmarks.py
Normal file
264
benchmarks/fused_kernels/merge_attn_states_benchmarks.py
Normal file
@@ -0,0 +1,264 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Benchmark: Fused FP8 output quantization in merge_attn_states
|
||||
|
||||
Compares fused vs unfused approaches for producing FP8-quantized merged
|
||||
attention output:
|
||||
1. Fused CUDA -- single CUDA kernel (merge + FP8 quant)
|
||||
2. Fused Triton -- single Triton kernel (merge + FP8 quant)
|
||||
3. Unfused CUDA -- CUDA merge + torch.compiled FP8 quant
|
||||
4. Unfused Triton -- Triton merge + torch.compiled FP8 quant
|
||||
|
||||
Usage:
|
||||
python benchmarks/fused_kernels/merge_attn_states_benchmarks.py
|
||||
python benchmarks/fused_kernels/merge_attn_states_benchmarks.py --tp 1 4 8
|
||||
python benchmarks/fused_kernels/merge_attn_states_benchmarks.py --dtype bfloat16
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
|
||||
from vllm._custom_ops import merge_attn_states as merge_attn_states_cuda
|
||||
from vllm.benchmarks.lib.utils import default_vllm_config
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import triton
|
||||
from vllm.v1.attention.ops.triton_merge_attn_states import (
|
||||
merge_attn_states as merge_attn_states_triton,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Configuration defaults
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
NUM_TOKENS_LIST = [1, 16, 64, 256, 1024, 4096]
|
||||
|
||||
# (label, num_heads, head_size) — num_heads is for TP=1
|
||||
HEAD_CONFIGS = [
|
||||
("DeepSeek-V3 MLA", 128, 128),
|
||||
("Llama-70B", 64, 128),
|
||||
("Llama-8B", 32, 128),
|
||||
]
|
||||
|
||||
TP_SIZES = [1, 2, 4, 8]
|
||||
|
||||
INPUT_DTYPES = [torch.float32, torch.float16, torch.bfloat16]
|
||||
|
||||
QUANTILES = [0.5, 0.2, 0.8]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def short_dtype(dtype: torch.dtype) -> str:
|
||||
return str(dtype).removeprefix("torch.")
|
||||
|
||||
|
||||
def make_inputs(
|
||||
num_tokens: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
"""Create random prefix/suffix outputs and LSEs."""
|
||||
prefix_output = torch.randn(
|
||||
(num_tokens, num_heads, head_size), dtype=dtype, device="cuda"
|
||||
)
|
||||
suffix_output = torch.randn(
|
||||
(num_tokens, num_heads, head_size), dtype=dtype, device="cuda"
|
||||
)
|
||||
prefix_lse = torch.randn(num_heads, num_tokens, dtype=torch.float32, device="cuda")
|
||||
suffix_lse = torch.randn(num_heads, num_tokens, dtype=torch.float32, device="cuda")
|
||||
# Sprinkle some inf values to exercise edge-case paths
|
||||
mask = torch.rand(num_heads, num_tokens, device="cuda") < 0.05
|
||||
prefix_lse[mask] = float("inf")
|
||||
mask2 = torch.rand(num_heads, num_tokens, device="cuda") < 0.05
|
||||
suffix_lse[mask2] = float("inf")
|
||||
return prefix_output, suffix_output, prefix_lse, suffix_lse
|
||||
|
||||
|
||||
def build_configs(head_configs, num_tokens_list, input_dtypes, tp_sizes):
|
||||
"""Build (num_tokens, num_heads, head_size, dtype_str) config tuples,
|
||||
applying TP division to num_heads and skipping invalid combos."""
|
||||
configs = []
|
||||
for (_, nh, hs), nt, dtype, tp in itertools.product(
|
||||
head_configs, num_tokens_list, input_dtypes, tp_sizes
|
||||
):
|
||||
nh_tp = nh // tp
|
||||
if nh_tp >= 1:
|
||||
configs.append((nt, nh_tp, hs, short_dtype(dtype)))
|
||||
return configs
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Benchmark merge_attn_states fused FP8 quantization"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-tokens",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=None,
|
||||
help=f"Override token counts (default: {NUM_TOKENS_LIST})",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tp",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=None,
|
||||
help=f"TP sizes to simulate (divides num_heads) (default: {TP_SIZES})",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
type=str,
|
||||
nargs="+",
|
||||
default=None,
|
||||
help="Input dtypes (e.g. bfloat16 float16 float32). "
|
||||
f"Default: {[short_dtype(d) for d in INPUT_DTYPES]}",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Parse args and build configs before decorators
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
args = parse_args()
|
||||
|
||||
num_tokens_list = args.num_tokens if args.num_tokens else NUM_TOKENS_LIST
|
||||
tp_sizes = args.tp if args.tp else TP_SIZES
|
||||
|
||||
if args.dtype:
|
||||
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||
|
||||
input_dtypes = [STR_DTYPE_TO_TORCH_DTYPE[d] for d in args.dtype]
|
||||
else:
|
||||
input_dtypes = INPUT_DTYPES
|
||||
|
||||
configs = build_configs(HEAD_CONFIGS, num_tokens_list, input_dtypes, tp_sizes)
|
||||
|
||||
torch._dynamo.config.recompile_limit = 8888
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Benchmark function
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["num_tokens", "num_heads", "head_size", "dtype_str"],
|
||||
x_vals=configs,
|
||||
line_arg="provider",
|
||||
line_vals=["fused_cuda", "fused_triton", "unfused_cuda", "unfused_triton"],
|
||||
line_names=["Fused CUDA", "Fused Triton", "Unfused CUDA", "Unfused Triton"],
|
||||
styles=[("blue", "-"), ("green", "-"), ("blue", "--"), ("green", "--")],
|
||||
ylabel="us",
|
||||
plot_name="merge_attn_states FP8 (fused vs unfused)",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
@default_vllm_config()
|
||||
def benchmark(num_tokens, num_heads, head_size, dtype_str, provider):
|
||||
input_dtype = getattr(torch, dtype_str)
|
||||
fp8_dtype = current_platform.fp8_dtype()
|
||||
prefix_out, suffix_out, prefix_lse, suffix_lse = make_inputs(
|
||||
num_tokens, num_heads, head_size, input_dtype
|
||||
)
|
||||
output_scale = torch.tensor([0.1], dtype=torch.float32, device="cuda")
|
||||
|
||||
if provider == "fused_cuda":
|
||||
output = torch.empty(
|
||||
(num_tokens, num_heads, head_size), dtype=fp8_dtype, device="cuda"
|
||||
)
|
||||
fn = lambda: merge_attn_states_cuda(
|
||||
output,
|
||||
prefix_out,
|
||||
prefix_lse,
|
||||
suffix_out,
|
||||
suffix_lse,
|
||||
output_scale=output_scale,
|
||||
)
|
||||
elif provider == "fused_triton":
|
||||
output = torch.empty(
|
||||
(num_tokens, num_heads, head_size), dtype=fp8_dtype, device="cuda"
|
||||
)
|
||||
fn = lambda: merge_attn_states_triton(
|
||||
output,
|
||||
prefix_out,
|
||||
prefix_lse,
|
||||
suffix_out,
|
||||
suffix_lse,
|
||||
output_scale=output_scale,
|
||||
)
|
||||
elif provider == "unfused_cuda":
|
||||
merge_buf = torch.empty(
|
||||
(num_tokens, num_heads, head_size), dtype=input_dtype, device="cuda"
|
||||
)
|
||||
quant_fp8 = QuantFP8(
|
||||
static=True,
|
||||
group_shape=GroupShape.PER_TENSOR,
|
||||
column_major_scales=False,
|
||||
)
|
||||
quant_input = merge_buf.view(-1, head_size)
|
||||
compiled_quant = torch.compile(
|
||||
quant_fp8.forward_native, fullgraph=True, dynamic=False
|
||||
)
|
||||
|
||||
def unfused_fn():
|
||||
merge_attn_states_cuda(
|
||||
merge_buf, prefix_out, prefix_lse, suffix_out, suffix_lse
|
||||
)
|
||||
compiled_quant(quant_input, output_scale)
|
||||
|
||||
fn = unfused_fn
|
||||
else: # unfused_triton
|
||||
merge_buf = torch.empty(
|
||||
(num_tokens, num_heads, head_size), dtype=input_dtype, device="cuda"
|
||||
)
|
||||
quant_fp8 = QuantFP8(
|
||||
static=True,
|
||||
group_shape=GroupShape.PER_TENSOR,
|
||||
column_major_scales=False,
|
||||
)
|
||||
quant_input = merge_buf.view(-1, head_size)
|
||||
compiled_quant = torch.compile(
|
||||
quant_fp8.forward_native, fullgraph=True, dynamic=False
|
||||
)
|
||||
|
||||
def unfused_fn():
|
||||
merge_attn_states_triton(
|
||||
merge_buf, prefix_out, prefix_lse, suffix_out, suffix_lse
|
||||
)
|
||||
compiled_quant(quant_input, output_scale)
|
||||
|
||||
fn = unfused_fn
|
||||
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=QUANTILES)
|
||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms # us
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def main():
|
||||
device_name = current_platform.get_device_name()
|
||||
print(f"Device: {device_name}")
|
||||
print(f"Token counts: {num_tokens_list}")
|
||||
print(f"TP sizes: {tp_sizes}")
|
||||
print(f"Input dtypes: {[short_dtype(d) for d in input_dtypes]}")
|
||||
print(f"Head configs: {[(c[0], c[1], c[2]) for c in HEAD_CONFIGS]}")
|
||||
benchmark.run(print_data=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
with torch.inference_mode():
|
||||
main()
|
||||
211
benchmarks/fused_kernels/silu_mul_block_quant_benchmark.py
Normal file
211
benchmarks/fused_kernels/silu_mul_block_quant_benchmark.py
Normal file
@@ -0,0 +1,211 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable, Iterable
|
||||
from dataclasses import dataclass
|
||||
from itertools import product
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.benchmark as TBenchmark
|
||||
from torch.utils.benchmark import Measurement as TMeasurement
|
||||
from tqdm import tqdm
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class bench_params_t:
|
||||
num_tokens: int
|
||||
hidden_size: int
|
||||
dtype: torch.dtype
|
||||
group_size: int # Changed from list[int] to int
|
||||
|
||||
def description(self):
|
||||
return (
|
||||
f"N {self.num_tokens} "
|
||||
f"x D {self.hidden_size} "
|
||||
f"x DT {self.dtype} "
|
||||
f"x GS {self.group_size}"
|
||||
)
|
||||
|
||||
|
||||
def get_bench_params() -> list[bench_params_t]:
|
||||
"""Test configurations covering common model sizes."""
|
||||
NUM_TOKENS = [16, 128, 512, 2048]
|
||||
HIDDEN_SIZES = [1024, 2048, 4096, 5120, 14336] # Common FFN sizes
|
||||
DTYPES = [torch.float16, torch.bfloat16]
|
||||
GROUP_SIZES = [64, 128] # Changed from [[1, 64], [1, 128]]
|
||||
|
||||
combinations = product(NUM_TOKENS, HIDDEN_SIZES, DTYPES, GROUP_SIZES)
|
||||
bench_params = list(
|
||||
map(lambda x: bench_params_t(x[0], x[1], x[2], x[3]), combinations)
|
||||
)
|
||||
return bench_params
|
||||
|
||||
|
||||
# Reference implementations
|
||||
def unfused_fp8_impl(
|
||||
x: torch.Tensor,
|
||||
quant_dtype: torch.dtype,
|
||||
group_size: int, # Changed from list[int]
|
||||
):
|
||||
"""Unfused: SiLU+Mul then per-tensor quantize."""
|
||||
hidden = x.shape[-1] // 2
|
||||
gate, up = x.split(hidden, dim=-1)
|
||||
|
||||
# SiLU(gate) * up
|
||||
silu_out = F.silu(gate) * up
|
||||
|
||||
# Per-tensor quantize (no group_size used here)
|
||||
silu_out, _ = ops.scaled_fp8_quant(silu_out)
|
||||
|
||||
|
||||
def unfused_groupwise_fp8_impl(
|
||||
x: torch.Tensor,
|
||||
quant_dtype: torch.dtype,
|
||||
group_size: int, # Changed from list[int]
|
||||
):
|
||||
"""Unfused: SiLU+Mul then group-wise quantize."""
|
||||
hidden = x.shape[-1] // 2
|
||||
gate, up = x.split(hidden, dim=-1)
|
||||
|
||||
# SiLU(gate) * up
|
||||
silu_out = F.silu(gate) * up
|
||||
|
||||
# Group quantize - use group_size directly
|
||||
silu_out, _ = per_token_group_quant_fp8(
|
||||
silu_out, group_size=group_size, use_ue8m0=False
|
||||
)
|
||||
|
||||
|
||||
def fused_impl(
|
||||
x: torch.Tensor,
|
||||
quant_dtype: torch.dtype,
|
||||
group_size: int,
|
||||
):
|
||||
"""Fused: SiLU+Mul+Block Quantization in single kernel."""
|
||||
out, _ = ops.silu_and_mul_per_block_quant(
|
||||
x,
|
||||
group_size=group_size,
|
||||
quant_dtype=quant_dtype,
|
||||
is_scale_transposed=False,
|
||||
)
|
||||
|
||||
|
||||
# Bench functions
|
||||
def bench_fn(
|
||||
x: torch.Tensor,
|
||||
quant_dtype: torch.dtype,
|
||||
group_size: int,
|
||||
label: str,
|
||||
sub_label: str,
|
||||
fn: Callable,
|
||||
description: str,
|
||||
) -> TMeasurement:
|
||||
min_run_time = 1
|
||||
|
||||
globals = {
|
||||
"x": x,
|
||||
"quant_dtype": quant_dtype,
|
||||
"group_size": group_size,
|
||||
"fn": fn,
|
||||
}
|
||||
return TBenchmark.Timer(
|
||||
stmt="fn(x, quant_dtype, group_size)",
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description=description,
|
||||
).blocked_autorange(min_run_time=min_run_time)
|
||||
|
||||
|
||||
def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasurement]:
|
||||
"""Run benchmarks for all implementations."""
|
||||
# Make inputs: [num_tokens, hidden_size * 2] for [gate || up]
|
||||
scale = 1 / params.hidden_size
|
||||
x = (
|
||||
torch.randn(
|
||||
params.num_tokens,
|
||||
params.hidden_size * 2,
|
||||
dtype=params.dtype,
|
||||
device="cuda",
|
||||
)
|
||||
* scale
|
||||
)
|
||||
|
||||
timers = []
|
||||
|
||||
# Unfused per-tensor FP8
|
||||
timers.append(
|
||||
bench_fn(
|
||||
x,
|
||||
torch.float8_e4m3fn,
|
||||
params.group_size,
|
||||
label,
|
||||
sub_label,
|
||||
unfused_fp8_impl,
|
||||
"unfused_fp8_impl",
|
||||
)
|
||||
)
|
||||
|
||||
# Unfused group-wise FP8
|
||||
timers.append(
|
||||
bench_fn(
|
||||
x,
|
||||
torch.float8_e4m3fn,
|
||||
params.group_size,
|
||||
label,
|
||||
sub_label,
|
||||
unfused_groupwise_fp8_impl,
|
||||
"unfused_groupwise_fp8_impl",
|
||||
)
|
||||
)
|
||||
|
||||
# Fused group-wise FP8
|
||||
timers.append(
|
||||
bench_fn(
|
||||
x,
|
||||
torch.float8_e4m3fn,
|
||||
params.group_size,
|
||||
label,
|
||||
sub_label,
|
||||
fused_impl,
|
||||
"fused_groupwise_fp8_impl",
|
||||
)
|
||||
)
|
||||
|
||||
return timers
|
||||
|
||||
|
||||
def print_timers(timers: Iterable[TMeasurement]):
|
||||
compare = TBenchmark.Compare(timers)
|
||||
compare.print()
|
||||
|
||||
|
||||
def main():
|
||||
torch.set_default_device("cuda")
|
||||
bench_params = get_bench_params()
|
||||
|
||||
print(f"Running {len(bench_params)} benchmark configurations...")
|
||||
print(
|
||||
f"This will take approximately {len(bench_params) * 3} seconds (1s per variant)"
|
||||
)
|
||||
print()
|
||||
|
||||
timers = []
|
||||
for bp in tqdm(bench_params):
|
||||
result_timers = bench(bp, "silu-mul-block-quant", bp.description())
|
||||
timers.extend(result_timers)
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("FINAL COMPARISON - ALL RESULTS")
|
||||
print("=" * 80)
|
||||
print_timers(timers)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -9,11 +9,12 @@ os.environ["VLLM_USE_DEEP_GEMM"] = "0"
|
||||
import torch
|
||||
|
||||
from vllm.benchmarks.lib.utils import default_vllm_config
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
W8A8BlockFp8LinearOp,
|
||||
from vllm.model_executor.kernels.linear import (
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
create_fp8_quant_key,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
CUTLASS_BLOCK_FP8_SUPPORTED,
|
||||
@@ -70,11 +71,15 @@ def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass):
|
||||
weight_group_shape = GroupShape(block_n, block_k)
|
||||
act_quant_group_shape = GroupShape(1, block_k) # Per-token, per-group quantization
|
||||
|
||||
linear_op = W8A8BlockFp8LinearOp(
|
||||
weight_group_shape=weight_group_shape,
|
||||
act_quant_group_shape=act_quant_group_shape,
|
||||
cutlass_block_fp8_supported=use_cutlass,
|
||||
use_aiter_and_is_supported=False,
|
||||
linear_op = init_fp8_linear_kernel(
|
||||
weight_quant_key=create_fp8_quant_key(
|
||||
static=True, group_shape=weight_group_shape
|
||||
),
|
||||
activation_quant_key=create_fp8_quant_key(
|
||||
static=False, group_shape=act_quant_group_shape
|
||||
),
|
||||
out_dtype=torch.get_default_dtype(),
|
||||
module_name="build_w8a8_block_fp8_runner",
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
@@ -1,134 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.transformers_utils.config import get_config
|
||||
from vllm.triton_utils import triton
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
|
||||
# Dimensions supported by the DSV3 specialized kernel
|
||||
DSV3_SUPPORTED_NUM_EXPERTS = [256, 384]
|
||||
DSV3_SUPPORTED_HIDDEN_SIZES = [7168]
|
||||
|
||||
# Dimensions supported by the gpt-oss specialized kernel
|
||||
GPT_OSS_SUPPORTED_NUM_EXPERTS = [32, 128]
|
||||
GPT_OSS_SUPPORTED_HIDDEN_SIZES = [2880]
|
||||
|
||||
|
||||
def get_batch_size_range(max_batch_size):
|
||||
return [2**x for x in range(14) if 2**x <= max_batch_size]
|
||||
|
||||
|
||||
def get_model_params(config):
|
||||
if config.architectures[0] in (
|
||||
"DeepseekV2ForCausalLM",
|
||||
"DeepseekV3ForCausalLM",
|
||||
"DeepseekV32ForCausalLM",
|
||||
):
|
||||
num_experts = config.n_routed_experts
|
||||
hidden_size = config.hidden_size
|
||||
elif config.architectures[0] in ("GptOssForCausalLM",):
|
||||
num_experts = config.num_local_experts
|
||||
hidden_size = config.hidden_size
|
||||
else:
|
||||
raise ValueError(f"Unsupported architecture: {config.architectures}")
|
||||
return num_experts, hidden_size
|
||||
|
||||
|
||||
def get_benchmark(model, max_batch_size, trust_remote_code):
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size"],
|
||||
x_vals=get_batch_size_range(max_batch_size),
|
||||
x_log=False,
|
||||
line_arg="provider",
|
||||
line_vals=[
|
||||
"torch",
|
||||
"vllm",
|
||||
],
|
||||
line_names=["PyTorch", "vLLM"],
|
||||
styles=([("blue", "-"), ("red", "-")]),
|
||||
ylabel="TFLOPs",
|
||||
plot_name=f"{model} router gemm throughput",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(batch_size, provider):
|
||||
config = get_config(model=model, trust_remote_code=trust_remote_code)
|
||||
num_experts, hidden_size = get_model_params(config)
|
||||
|
||||
mat_a = torch.randn(
|
||||
(batch_size, hidden_size), dtype=torch.bfloat16, device="cuda"
|
||||
).contiguous()
|
||||
mat_b = torch.randn(
|
||||
(num_experts, hidden_size), dtype=torch.bfloat16, device="cuda"
|
||||
).contiguous()
|
||||
bias = torch.randn(
|
||||
num_experts, dtype=torch.bfloat16, device="cuda"
|
||||
).contiguous()
|
||||
|
||||
is_hopper_or_blackwell = current_platform.is_device_capability(
|
||||
90
|
||||
) or current_platform.is_device_capability_family(100)
|
||||
allow_dsv3_router_gemm = (
|
||||
is_hopper_or_blackwell
|
||||
and num_experts in DSV3_SUPPORTED_NUM_EXPERTS
|
||||
and hidden_size in DSV3_SUPPORTED_HIDDEN_SIZES
|
||||
)
|
||||
allow_gpt_oss_router_gemm = (
|
||||
is_hopper_or_blackwell
|
||||
and num_experts in GPT_OSS_SUPPORTED_NUM_EXPERTS
|
||||
and hidden_size in GPT_OSS_SUPPORTED_HIDDEN_SIZES
|
||||
)
|
||||
|
||||
has_bias = False
|
||||
if allow_gpt_oss_router_gemm:
|
||||
has_bias = True
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
if provider == "torch":
|
||||
|
||||
def runner():
|
||||
if has_bias:
|
||||
F.linear(mat_a, mat_b, bias)
|
||||
else:
|
||||
F.linear(mat_a, mat_b)
|
||||
elif provider == "vllm":
|
||||
|
||||
def runner():
|
||||
if allow_dsv3_router_gemm:
|
||||
ops.dsv3_router_gemm(mat_a, mat_b, torch.bfloat16)
|
||||
elif allow_gpt_oss_router_gemm:
|
||||
ops.gpt_oss_router_gemm(mat_a, mat_b, bias)
|
||||
else:
|
||||
raise ValueError("Unsupported router gemm")
|
||||
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
runner, quantiles=quantiles
|
||||
)
|
||||
|
||||
def tflops(t_ms):
|
||||
flops = 2 * batch_size * hidden_size * num_experts
|
||||
return flops / (t_ms * 1e-3) / 1e12
|
||||
|
||||
return tflops(ms), tflops(max_ms), tflops(min_ms)
|
||||
|
||||
return benchmark
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser()
|
||||
parser.add_argument("--model", type=str, default="openai/gpt-oss-20b")
|
||||
parser.add_argument("--max-batch-size", default=16, type=int)
|
||||
parser.add_argument("--trust-remote-code", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Get the benchmark function
|
||||
benchmark = get_benchmark(args.model, args.max_batch_size, args.trust_remote_code)
|
||||
# Run performance benchmark
|
||||
benchmark.run(print_data=True)
|
||||
@@ -20,7 +20,7 @@ import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
||||
from vllm.model_executor.layers.fused_moe.experts.batched_deep_gemm_moe import (
|
||||
persistent_masked_m_silu_mul_quant,
|
||||
)
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
162
benchmarks/kernels/benchmark_vit_bilinear_pos_embed.py
Normal file
162
benchmarks/kernels/benchmark_vit_bilinear_pos_embed.py
Normal file
@@ -0,0 +1,162 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Benchmarks the fused Triton bilinear position-embedding kernel against
|
||||
# the pure-PyTorch (native) implementation used in Qwen3-VL ViT models.
|
||||
#
|
||||
# == Usage Examples ==
|
||||
#
|
||||
# Default benchmark:
|
||||
# python3 benchmark_vit_bilinear_pos_embed.py
|
||||
#
|
||||
# Custom parameters:
|
||||
# python3 benchmark_vit_bilinear_pos_embed.py --hidden-dim 1152 \
|
||||
# --num-grid-per-side 48 --save-path ./configs/vit_pos_embed/
|
||||
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.models.qwen3_vl import (
|
||||
pos_embed_interpolate_native,
|
||||
triton_pos_embed_interpolate,
|
||||
)
|
||||
from vllm.triton_utils import HAS_TRITON, triton
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
|
||||
# (h, w) configurations to benchmark
|
||||
h_w_configs = [
|
||||
(16, 16),
|
||||
(32, 32),
|
||||
(48, 48),
|
||||
(64, 64),
|
||||
(128, 128),
|
||||
(32, 48),
|
||||
(60, 80),
|
||||
]
|
||||
|
||||
# Temporal dimensions
|
||||
t_range = [1]
|
||||
|
||||
configs = list(itertools.product(t_range, h_w_configs))
|
||||
|
||||
|
||||
def get_benchmark(
|
||||
num_grid_per_side: int,
|
||||
spatial_merge_size: int,
|
||||
hidden_dim: int,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
):
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["t", "h_w"],
|
||||
x_vals=[list(_) for _ in configs],
|
||||
line_arg="provider",
|
||||
line_vals=["native", "triton"],
|
||||
line_names=["Native (PyTorch)", "Triton"],
|
||||
styles=[("blue", "-"), ("red", "-")],
|
||||
ylabel="us",
|
||||
plot_name=(
|
||||
f"vit-bilinear-pos-embed-"
|
||||
f"grid{num_grid_per_side}-"
|
||||
f"dim{hidden_dim}-"
|
||||
f"{dtype}"
|
||||
),
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(t, h_w, provider):
|
||||
h, w = h_w
|
||||
|
||||
torch.manual_seed(42)
|
||||
embed_weight = (
|
||||
torch.randn(
|
||||
num_grid_per_side * num_grid_per_side,
|
||||
hidden_dim,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
* 0.25
|
||||
)
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
if provider == "native":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: pos_embed_interpolate_native(
|
||||
embed_weight,
|
||||
t,
|
||||
h,
|
||||
w,
|
||||
num_grid_per_side,
|
||||
spatial_merge_size,
|
||||
dtype,
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
else:
|
||||
assert HAS_TRITON, "Triton not available"
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: triton_pos_embed_interpolate(
|
||||
embed_weight,
|
||||
t,
|
||||
h,
|
||||
w,
|
||||
num_grid_per_side,
|
||||
spatial_merge_size,
|
||||
dtype,
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
|
||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||
|
||||
return benchmark
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser(
|
||||
description="Benchmark bilinear position embedding interpolation."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-grid-per-side",
|
||||
type=int,
|
||||
default=48,
|
||||
help="Position embedding grid size (default: 48 for Qwen3-VL)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--spatial-merge-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Spatial merge size (default: 2)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hidden-dim",
|
||||
type=int,
|
||||
default=1152,
|
||||
help="Embedding hidden dimension (default: 1152 for Qwen3-VL)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
choices=["cuda:0", "cuda:1"],
|
||||
default="cuda:0",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-path",
|
||||
type=str,
|
||||
default="./vit_pos_embed/",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
dtype = torch.bfloat16
|
||||
|
||||
bench = get_benchmark(
|
||||
args.num_grid_per_side,
|
||||
args.spatial_merge_size,
|
||||
args.hidden_dim,
|
||||
dtype,
|
||||
args.device,
|
||||
)
|
||||
bench.run(print_data=True, save_path=args.save_path)
|
||||
151
cmake/external_projects/deepgemm.cmake
Normal file
151
cmake/external_projects/deepgemm.cmake
Normal file
@@ -0,0 +1,151 @@
|
||||
include(FetchContent)
|
||||
|
||||
# If DEEPGEMM_SRC_DIR is set, DeepGEMM is built from that directory
|
||||
# instead of downloading.
|
||||
# It can be set as an environment variable or passed as a cmake argument.
|
||||
# The environment variable takes precedence.
|
||||
if (DEFINED ENV{DEEPGEMM_SRC_DIR})
|
||||
set(DEEPGEMM_SRC_DIR $ENV{DEEPGEMM_SRC_DIR})
|
||||
endif()
|
||||
|
||||
if(DEEPGEMM_SRC_DIR)
|
||||
FetchContent_Declare(
|
||||
deepgemm
|
||||
SOURCE_DIR ${DEEPGEMM_SRC_DIR}
|
||||
CONFIGURE_COMMAND ""
|
||||
BUILD_COMMAND ""
|
||||
)
|
||||
else()
|
||||
# This ref should be kept in sync with tools/install_deepgemm.sh
|
||||
FetchContent_Declare(
|
||||
deepgemm
|
||||
GIT_REPOSITORY https://github.com/deepseek-ai/DeepGEMM.git
|
||||
GIT_TAG 477618cd51baffca09c4b0b87e97c03fe827ef03
|
||||
GIT_SUBMODULES "third-party/cutlass" "third-party/fmt"
|
||||
GIT_PROGRESS TRUE
|
||||
CONFIGURE_COMMAND ""
|
||||
BUILD_COMMAND ""
|
||||
)
|
||||
endif()
|
||||
|
||||
# Use FetchContent_Populate (not MakeAvailable) to avoid processing
|
||||
# DeepGEMM's own CMakeLists.txt which has incompatible find_package calls.
|
||||
FetchContent_GetProperties(deepgemm)
|
||||
if(NOT deepgemm_POPULATED)
|
||||
FetchContent_Populate(deepgemm)
|
||||
endif()
|
||||
message(STATUS "DeepGEMM is available at ${deepgemm_SOURCE_DIR}")
|
||||
|
||||
# DeepGEMM requires CUDA 12.3+ for SM90, 12.9+ for SM100
|
||||
set(DEEPGEMM_SUPPORT_ARCHS)
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3)
|
||||
list(APPEND DEEPGEMM_SUPPORT_ARCHS "9.0a")
|
||||
endif()
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.9)
|
||||
list(APPEND DEEPGEMM_SUPPORT_ARCHS "10.0f")
|
||||
elseif(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
list(APPEND DEEPGEMM_SUPPORT_ARCHS "10.0a")
|
||||
endif()
|
||||
|
||||
cuda_archs_loose_intersection(DEEPGEMM_ARCHS
|
||||
"${DEEPGEMM_SUPPORT_ARCHS}" "${CUDA_ARCHS}")
|
||||
|
||||
if(DEEPGEMM_ARCHS)
|
||||
message(STATUS "DeepGEMM CUDA architectures: ${DEEPGEMM_ARCHS}")
|
||||
|
||||
find_package(CUDAToolkit REQUIRED)
|
||||
|
||||
#
|
||||
# Build the _C pybind11 extension from DeepGEMM's C++ source.
|
||||
# This is a CXX-only module — CUDA kernels are JIT-compiled at runtime.
|
||||
#
|
||||
Python_add_library(_deep_gemm_C MODULE WITH_SOABI
|
||||
"${deepgemm_SOURCE_DIR}/csrc/python_api.cpp")
|
||||
|
||||
# The pybind11 module name must be _C to match DeepGEMM's Python imports.
|
||||
set_target_properties(_deep_gemm_C PROPERTIES OUTPUT_NAME "_C")
|
||||
|
||||
target_compile_definitions(_deep_gemm_C PRIVATE
|
||||
"-DTORCH_EXTENSION_NAME=_C")
|
||||
|
||||
target_include_directories(_deep_gemm_C PRIVATE
|
||||
"${deepgemm_SOURCE_DIR}/csrc"
|
||||
"${deepgemm_SOURCE_DIR}/deep_gemm/include"
|
||||
"${deepgemm_SOURCE_DIR}/third-party/cutlass/include"
|
||||
"${deepgemm_SOURCE_DIR}/third-party/cutlass/tools/util/include"
|
||||
"${deepgemm_SOURCE_DIR}/third-party/fmt/include")
|
||||
|
||||
target_compile_options(_deep_gemm_C PRIVATE
|
||||
$<$<COMPILE_LANGUAGE:CXX>:-std=c++17>
|
||||
$<$<COMPILE_LANGUAGE:CXX>:-O3>
|
||||
$<$<COMPILE_LANGUAGE:CXX>:-Wno-psabi>
|
||||
$<$<COMPILE_LANGUAGE:CXX>:-Wno-deprecated-declarations>)
|
||||
|
||||
# torch_python is required because DeepGEMM uses pybind11 type casters
|
||||
# for at::Tensor (via PYBIND11_MODULE), unlike vLLM's own extensions which
|
||||
# use torch::Library custom ops.
|
||||
find_library(TORCH_PYTHON_LIBRARY torch_python
|
||||
PATHS "${TORCH_INSTALL_PREFIX}/lib"
|
||||
REQUIRED)
|
||||
|
||||
target_link_libraries(_deep_gemm_C PRIVATE
|
||||
torch ${TORCH_LIBRARIES} "${TORCH_PYTHON_LIBRARY}"
|
||||
CUDA::cudart CUDA::nvrtc)
|
||||
|
||||
# Install the shared library into the vendored package directory
|
||||
install(TARGETS _deep_gemm_C
|
||||
LIBRARY DESTINATION vllm/third_party/deep_gemm
|
||||
COMPONENT _deep_gemm_C)
|
||||
|
||||
#
|
||||
# Vendor DeepGEMM Python package files
|
||||
#
|
||||
install(FILES
|
||||
"${deepgemm_SOURCE_DIR}/deep_gemm/__init__.py"
|
||||
DESTINATION vllm/third_party/deep_gemm
|
||||
COMPONENT _deep_gemm_C)
|
||||
|
||||
install(DIRECTORY "${deepgemm_SOURCE_DIR}/deep_gemm/utils/"
|
||||
DESTINATION vllm/third_party/deep_gemm/utils
|
||||
COMPONENT _deep_gemm_C
|
||||
FILES_MATCHING PATTERN "*.py")
|
||||
|
||||
install(DIRECTORY "${deepgemm_SOURCE_DIR}/deep_gemm/testing/"
|
||||
DESTINATION vllm/third_party/deep_gemm/testing
|
||||
COMPONENT _deep_gemm_C
|
||||
FILES_MATCHING PATTERN "*.py")
|
||||
|
||||
install(DIRECTORY "${deepgemm_SOURCE_DIR}/deep_gemm/legacy/"
|
||||
DESTINATION vllm/third_party/deep_gemm/legacy
|
||||
COMPONENT _deep_gemm_C
|
||||
FILES_MATCHING PATTERN "*.py")
|
||||
|
||||
# Generate envs.py (normally generated by DeepGEMM's setup.py build step)
|
||||
file(WRITE "${CMAKE_CURRENT_BINARY_DIR}/deep_gemm_envs.py"
|
||||
"# Pre-installed environment variables\npersistent_envs = dict()\n")
|
||||
install(FILES "${CMAKE_CURRENT_BINARY_DIR}/deep_gemm_envs.py"
|
||||
DESTINATION vllm/third_party/deep_gemm
|
||||
RENAME envs.py
|
||||
COMPONENT _deep_gemm_C)
|
||||
|
||||
#
|
||||
# Install include files needed for JIT compilation at runtime.
|
||||
# The JIT compiler finds these relative to the package directory.
|
||||
#
|
||||
|
||||
# DeepGEMM's own CUDA headers
|
||||
install(DIRECTORY "${deepgemm_SOURCE_DIR}/deep_gemm/include/"
|
||||
DESTINATION vllm/third_party/deep_gemm/include
|
||||
COMPONENT _deep_gemm_C)
|
||||
|
||||
# CUTLASS and CuTe headers (vendored for JIT, separate from vLLM's CUTLASS)
|
||||
install(DIRECTORY "${deepgemm_SOURCE_DIR}/third-party/cutlass/include/"
|
||||
DESTINATION vllm/third_party/deep_gemm/include
|
||||
COMPONENT _deep_gemm_C)
|
||||
|
||||
else()
|
||||
message(STATUS "DeepGEMM will not compile: "
|
||||
"unsupported CUDA architecture ${CUDA_ARCHS}")
|
||||
# Create empty target so setup.py doesn't fail on unsupported systems
|
||||
add_custom_target(_deep_gemm_C)
|
||||
endif()
|
||||
@@ -39,7 +39,7 @@ else()
|
||||
FetchContent_Declare(
|
||||
vllm-flash-attn
|
||||
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
|
||||
GIT_TAG 29210221863736a08f71a866459e368ad1ac4a95
|
||||
GIT_TAG f5bc33cfc02c744d24a2e9d50e6db656de40611c
|
||||
GIT_PROGRESS TRUE
|
||||
# Don't share the vllm-flash-attn build between build types
|
||||
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
|
||||
@@ -87,18 +87,30 @@ endforeach()
|
||||
#
|
||||
add_custom_target(_vllm_fa4_cutedsl_C)
|
||||
|
||||
# Copy flash_attn/cute directory (needed for FA4) and transform imports
|
||||
# The cute directory uses flash_attn.cute imports internally, which we replace
|
||||
# with vllm.vllm_flash_attn.cute to match our package structure.
|
||||
install(CODE "
|
||||
file(GLOB_RECURSE CUTE_PY_FILES \"${vllm-flash-attn_SOURCE_DIR}/flash_attn/cute/*.py\")
|
||||
foreach(SRC_FILE \${CUTE_PY_FILES})
|
||||
file(RELATIVE_PATH REL_PATH \"${vllm-flash-attn_SOURCE_DIR}/flash_attn/cute\" \${SRC_FILE})
|
||||
set(DST_FILE \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn/cute/\${REL_PATH}\")
|
||||
get_filename_component(DST_DIR \${DST_FILE} DIRECTORY)
|
||||
file(MAKE_DIRECTORY \${DST_DIR})
|
||||
file(READ \${SRC_FILE} FILE_CONTENTS)
|
||||
string(REPLACE \"flash_attn.cute\" \"vllm.vllm_flash_attn.cute\" FILE_CONTENTS \"\${FILE_CONTENTS}\")
|
||||
file(WRITE \${DST_FILE} \"\${FILE_CONTENTS}\")
|
||||
endforeach()
|
||||
" COMPONENT _vllm_fa4_cutedsl_C)
|
||||
# Install flash_attn/cute directory (needed for FA4).
|
||||
# When using a local source dir (VLLM_FLASH_ATTN_SRC_DIR), create a symlink
|
||||
# so edits to cute-dsl Python files take effect immediately without rebuilding.
|
||||
# Otherwise, copy files and transform flash_attn.cute imports to
|
||||
# vllm.vllm_flash_attn.cute to match our package structure.
|
||||
if(VLLM_FLASH_ATTN_SRC_DIR)
|
||||
install(CODE "
|
||||
set(LINK_TARGET \"${vllm-flash-attn_SOURCE_DIR}/flash_attn/cute\")
|
||||
set(LINK_NAME \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn/cute\")
|
||||
file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn\")
|
||||
file(REMOVE_RECURSE \"\${LINK_NAME}\")
|
||||
file(CREATE_LINK \"\${LINK_TARGET}\" \"\${LINK_NAME}\" SYMBOLIC)
|
||||
" COMPONENT _vllm_fa4_cutedsl_C)
|
||||
else()
|
||||
install(CODE "
|
||||
file(GLOB_RECURSE CUTE_PY_FILES \"${vllm-flash-attn_SOURCE_DIR}/flash_attn/cute/*.py\")
|
||||
foreach(SRC_FILE \${CUTE_PY_FILES})
|
||||
file(RELATIVE_PATH REL_PATH \"${vllm-flash-attn_SOURCE_DIR}/flash_attn/cute\" \${SRC_FILE})
|
||||
set(DST_FILE \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn/cute/\${REL_PATH}\")
|
||||
get_filename_component(DST_DIR \${DST_FILE} DIRECTORY)
|
||||
file(MAKE_DIRECTORY \${DST_DIR})
|
||||
file(READ \${SRC_FILE} FILE_CONTENTS)
|
||||
string(REPLACE \"flash_attn.cute\" \"vllm.vllm_flash_attn.cute\" FILE_CONTENTS \"\${FILE_CONTENTS}\")
|
||||
file(WRITE \${DST_FILE} \"\${FILE_CONTENTS}\")
|
||||
endforeach()
|
||||
" COMPONENT _vllm_fa4_cutedsl_C)
|
||||
endif()
|
||||
|
||||
@@ -3,22 +3,33 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
|
||||
#include "attention_dtypes.h"
|
||||
#include "attention_utils.cuh"
|
||||
#include "../quantization/w8a8/fp8/common.cuh"
|
||||
#include "../dispatch_utils.h"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
|
||||
// can be used to combine partial attention results (in the split-KV case)
|
||||
template <typename scalar_t, const uint NUM_THREADS>
|
||||
template <typename scalar_t, typename output_t, const uint NUM_THREADS,
|
||||
bool USE_FP8_OUTPUT>
|
||||
__global__ void merge_attn_states_kernel(
|
||||
scalar_t* output, float* output_lse, const scalar_t* prefix_output,
|
||||
output_t* output, float* output_lse, const scalar_t* prefix_output,
|
||||
const float* prefix_lse, const scalar_t* suffix_output,
|
||||
const float* suffix_lse, const uint num_tokens, const uint num_heads,
|
||||
const uint head_size, const uint prefix_head_stride,
|
||||
const uint output_head_stride) {
|
||||
using pack_128b_t = uint4;
|
||||
const uint output_head_stride, const uint prefix_num_tokens,
|
||||
const float* output_scale) {
|
||||
// Inputs always load 128-bit packs (pack_size elements of scalar_t).
|
||||
// Outputs store pack_size elements of output_t, which is smaller for FP8.
|
||||
using input_pack_t = uint4;
|
||||
using output_pack_t =
|
||||
std::conditional_t<USE_FP8_OUTPUT,
|
||||
std::conditional_t<sizeof(scalar_t) == 4, uint, uint2>,
|
||||
uint4>;
|
||||
const uint pack_size = 16 / sizeof(scalar_t);
|
||||
const uint threads_per_head = head_size / pack_size;
|
||||
|
||||
@@ -41,8 +52,45 @@ __global__ void merge_attn_states_kernel(
|
||||
head_idx * output_head_stride;
|
||||
const scalar_t* prefix_head_ptr = prefix_output + src_head_offset;
|
||||
const scalar_t* suffix_head_ptr = suffix_output + src_head_offset;
|
||||
scalar_t* output_head_ptr = output + dst_head_offset;
|
||||
output_t* output_head_ptr = output + dst_head_offset;
|
||||
|
||||
// Pre-invert scale: multiplication is faster than division
|
||||
float fp8_scale_inv = 1.0f;
|
||||
if constexpr (USE_FP8_OUTPUT) {
|
||||
fp8_scale_inv = 1.0f / *output_scale;
|
||||
}
|
||||
|
||||
// If token_idx >= prefix_num_tokens, just copy from suffix
|
||||
if (token_idx >= prefix_num_tokens) {
|
||||
if (pack_offset < head_size) {
|
||||
input_pack_t s_out_pack = reinterpret_cast<const input_pack_t*>(
|
||||
suffix_head_ptr)[pack_offset / pack_size];
|
||||
|
||||
if constexpr (USE_FP8_OUTPUT) {
|
||||
output_t o_out_pack[pack_size];
|
||||
#pragma unroll
|
||||
for (uint i = 0; i < pack_size; ++i) {
|
||||
const float val =
|
||||
vllm::to_float(reinterpret_cast<const scalar_t*>(&s_out_pack)[i]);
|
||||
o_out_pack[i] =
|
||||
vllm::scaled_fp8_conversion<true, output_t>(val, fp8_scale_inv);
|
||||
}
|
||||
reinterpret_cast<output_pack_t*>(
|
||||
output_head_ptr)[pack_offset / pack_size] =
|
||||
*reinterpret_cast<output_pack_t*>(o_out_pack);
|
||||
} else {
|
||||
reinterpret_cast<output_pack_t*>(
|
||||
output_head_ptr)[pack_offset / pack_size] = s_out_pack;
|
||||
}
|
||||
}
|
||||
if (output_lse != nullptr && pack_idx == 0) {
|
||||
float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
|
||||
output_lse[head_idx * num_tokens + token_idx] = s_lse;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// For tokens within prefix range, merge prefix and suffix
|
||||
float p_lse = prefix_lse[head_idx * num_tokens + token_idx];
|
||||
float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
|
||||
p_lse = std::isinf(p_lse) ? -std::numeric_limits<float>::infinity() : p_lse;
|
||||
@@ -53,20 +101,34 @@ __global__ void merge_attn_states_kernel(
|
||||
/* In certain edge cases, MLA can produce p_lse = s_lse = -inf;
|
||||
continuing the pipeline then yields NaN. Root cause: with chunked prefill
|
||||
a batch may be split into two chunks; if a request in that batch has no
|
||||
prefix hit, every LSE entry for that request’s position is -inf, and at
|
||||
prefix hit, every LSE entry for that request's position is -inf, and at
|
||||
this moment we merge cross-attention at first. For now we simply emit
|
||||
prefix_output (expected to be all zeros) and prefix_lse (-inf) to fix
|
||||
this problem.
|
||||
*/
|
||||
if (std::isinf(max_lse)) {
|
||||
if (pack_offset < head_size) {
|
||||
// Pack 128b load
|
||||
pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(
|
||||
input_pack_t p_out_pack = reinterpret_cast<const input_pack_t*>(
|
||||
prefix_head_ptr)[pack_offset / pack_size];
|
||||
|
||||
// Pack 128b storage
|
||||
reinterpret_cast<pack_128b_t*>(output_head_ptr)[pack_offset / pack_size] =
|
||||
p_out_pack;
|
||||
if constexpr (USE_FP8_OUTPUT) {
|
||||
// Convert prefix values to FP8 (since -inf means no data,
|
||||
// prefix_output is expected to be zeros)
|
||||
output_t o_out_pack[pack_size];
|
||||
#pragma unroll
|
||||
for (uint i = 0; i < pack_size; ++i) {
|
||||
const float val =
|
||||
vllm::to_float(reinterpret_cast<const scalar_t*>(&p_out_pack)[i]);
|
||||
o_out_pack[i] =
|
||||
vllm::scaled_fp8_conversion<true, output_t>(val, fp8_scale_inv);
|
||||
}
|
||||
reinterpret_cast<output_pack_t*>(
|
||||
output_head_ptr)[pack_offset / pack_size] =
|
||||
*reinterpret_cast<output_pack_t*>(o_out_pack);
|
||||
} else {
|
||||
reinterpret_cast<output_pack_t*>(
|
||||
output_head_ptr)[pack_offset / pack_size] = p_out_pack;
|
||||
}
|
||||
}
|
||||
// We only need to write to output_lse once per head.
|
||||
if (output_lse != nullptr && pack_idx == 0) {
|
||||
@@ -84,30 +146,43 @@ __global__ void merge_attn_states_kernel(
|
||||
const float s_scale = s_se / out_se;
|
||||
|
||||
if (pack_offset < head_size) {
|
||||
// Pack 128b load
|
||||
pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(
|
||||
input_pack_t p_out_pack = reinterpret_cast<const input_pack_t*>(
|
||||
prefix_head_ptr)[pack_offset / pack_size];
|
||||
pack_128b_t s_out_pack = reinterpret_cast<const pack_128b_t*>(
|
||||
input_pack_t s_out_pack = reinterpret_cast<const input_pack_t*>(
|
||||
suffix_head_ptr)[pack_offset / pack_size];
|
||||
pack_128b_t o_out_pack;
|
||||
|
||||
// Compute merged values in float32
|
||||
float o_out_f[pack_size];
|
||||
#pragma unroll
|
||||
for (uint i = 0; i < pack_size; ++i) {
|
||||
// Always use float for FMA to keep high precision.
|
||||
// half(uint16_t), bfloat16, float -> float.
|
||||
const float p_out_f =
|
||||
vllm::to_float(reinterpret_cast<const scalar_t*>(&p_out_pack)[i]);
|
||||
const float s_out_f =
|
||||
vllm::to_float(reinterpret_cast<const scalar_t*>(&s_out_pack)[i]);
|
||||
// fma: a * b + c = p_out_f * p_scale + (s_out_f * s_scale)
|
||||
const float o_out_f = p_out_f * p_scale + (s_out_f * s_scale);
|
||||
// float -> half(uint16_t), bfloat16, float.
|
||||
vllm::from_float(reinterpret_cast<scalar_t*>(&o_out_pack)[i], o_out_f);
|
||||
o_out_f[i] = p_out_f * p_scale + (s_out_f * s_scale);
|
||||
}
|
||||
|
||||
// Pack 128b storage
|
||||
reinterpret_cast<pack_128b_t*>(output_head_ptr)[pack_offset / pack_size] =
|
||||
o_out_pack;
|
||||
// Convert and store
|
||||
if constexpr (USE_FP8_OUTPUT) {
|
||||
output_t o_out_pack[pack_size];
|
||||
#pragma unroll
|
||||
for (uint i = 0; i < pack_size; ++i) {
|
||||
o_out_pack[i] = vllm::scaled_fp8_conversion<true, output_t>(
|
||||
o_out_f[i], fp8_scale_inv);
|
||||
}
|
||||
reinterpret_cast<output_pack_t*>(
|
||||
output_head_ptr)[pack_offset / pack_size] =
|
||||
*reinterpret_cast<output_pack_t*>(o_out_pack);
|
||||
} else {
|
||||
output_pack_t o_out_pack;
|
||||
#pragma unroll
|
||||
for (uint i = 0; i < pack_size; ++i) {
|
||||
vllm::from_float(reinterpret_cast<scalar_t*>(&o_out_pack)[i],
|
||||
o_out_f[i]);
|
||||
}
|
||||
reinterpret_cast<output_pack_t*>(
|
||||
output_head_ptr)[pack_offset / pack_size] = o_out_pack;
|
||||
}
|
||||
}
|
||||
// We only need to write to output_lse once per head.
|
||||
if (output_lse != nullptr && pack_idx == 0) {
|
||||
@@ -134,50 +209,73 @@ __global__ void merge_attn_states_kernel(
|
||||
} \
|
||||
}
|
||||
|
||||
#define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS) \
|
||||
#define LAUNCH_MERGE_ATTN_STATES(scalar_t, output_t, NUM_THREADS, \
|
||||
USE_FP8_OUTPUT) \
|
||||
{ \
|
||||
vllm::merge_attn_states_kernel<scalar_t, NUM_THREADS> \
|
||||
vllm::merge_attn_states_kernel<scalar_t, output_t, NUM_THREADS, \
|
||||
USE_FP8_OUTPUT> \
|
||||
<<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<scalar_t*>(output.data_ptr()), output_lse_ptr, \
|
||||
reinterpret_cast<output_t*>(output.data_ptr()), output_lse_ptr, \
|
||||
reinterpret_cast<scalar_t*>(prefix_output.data_ptr()), \
|
||||
reinterpret_cast<float*>(prefix_lse.data_ptr()), \
|
||||
reinterpret_cast<scalar_t*>(suffix_output.data_ptr()), \
|
||||
reinterpret_cast<float*>(suffix_lse.data_ptr()), num_tokens, \
|
||||
num_heads, head_size, prefix_head_stride, output_head_stride); \
|
||||
num_heads, head_size, prefix_head_stride, output_head_stride, \
|
||||
prefix_num_tokens, output_scale_ptr); \
|
||||
}
|
||||
|
||||
/*@brief Merges the attention states from prefix and suffix
|
||||
* into the output tensor. NUM_TOKENS: n, NUM_HEADS: h, HEAD_SIZE: d
|
||||
*
|
||||
* @param output [n,h,d] The output tensor to store the merged attention states.
|
||||
* @param output_lse [h,d] Optional tensor to store the log-sum-exp values.
|
||||
* @param output_lse [h,n] Optional tensor to store the log-sum-exp values.
|
||||
* @param prefix_output [n,h,d] The prefix attention states.
|
||||
* @param prefix_lse [h,n] The log-sum-exp values for the prefix attention
|
||||
* states.
|
||||
* @param suffix_output [n,h,d] The suffix attention states.
|
||||
* @param suffix_lse [h,n] The log-sum-exp values for the suffix attention
|
||||
* states.
|
||||
* @param prefill_tokens_with_context Number of prefill tokens with context
|
||||
* For the first p tokens (0 <= token_idx < prefill_tokens_with_context), output
|
||||
* is computed by merging prefix_output and suffix_output. For remaining tokens
|
||||
* (prefill_tokens_with_context <= token_idx < n), output is copied directly
|
||||
* from suffix_output.
|
||||
* @param output_scale Optional scalar tensor for FP8 static quantization.
|
||||
* When provided, output must be FP8 dtype.
|
||||
*/
|
||||
template <typename scalar_t>
|
||||
void merge_attn_states_launcher(torch::Tensor& output,
|
||||
std::optional<torch::Tensor> output_lse,
|
||||
const torch::Tensor& prefix_output,
|
||||
const torch::Tensor& prefix_lse,
|
||||
const torch::Tensor& suffix_output,
|
||||
const torch::Tensor& suffix_lse) {
|
||||
void merge_attn_states_launcher(
|
||||
torch::Tensor& output, std::optional<torch::Tensor> output_lse,
|
||||
const torch::Tensor& prefix_output, const torch::Tensor& prefix_lse,
|
||||
const torch::Tensor& suffix_output, const torch::Tensor& suffix_lse,
|
||||
const std::optional<int64_t> prefill_tokens_with_context,
|
||||
const std::optional<torch::Tensor>& output_scale) {
|
||||
constexpr uint NUM_THREADS = 128;
|
||||
const uint num_tokens = output.size(0);
|
||||
const uint num_heads = output.size(1);
|
||||
const uint head_size = output.size(2);
|
||||
const uint prefix_head_stride = prefix_output.stride(1);
|
||||
const uint output_head_stride = output.stride(1);
|
||||
// Thread mapping is based on input BF16 pack_size
|
||||
const uint pack_size = 16 / sizeof(scalar_t);
|
||||
TORCH_CHECK(head_size % pack_size == 0,
|
||||
"headsize must be multiple of pack_size:", pack_size);
|
||||
|
||||
const uint prefix_num_tokens =
|
||||
prefill_tokens_with_context.has_value()
|
||||
? static_cast<uint>(prefill_tokens_with_context.value())
|
||||
: num_tokens;
|
||||
TORCH_CHECK(prefix_num_tokens <= num_tokens,
|
||||
"prefix_num_tokens must be <= num_tokens");
|
||||
|
||||
float* output_lse_ptr = nullptr;
|
||||
if (output_lse.has_value()) {
|
||||
output_lse_ptr = output_lse.value().data_ptr<float>();
|
||||
}
|
||||
float* output_scale_ptr = nullptr;
|
||||
if (output_scale.has_value()) {
|
||||
output_scale_ptr = output_scale.value().data_ptr<float>();
|
||||
}
|
||||
// Process one pack elements per thread. for float, the
|
||||
// pack_size is 4 for half/bf16, the pack_size is 8.
|
||||
const uint threads_per_head = head_size / pack_size;
|
||||
@@ -189,14 +287,22 @@ void merge_attn_states_launcher(torch::Tensor& output,
|
||||
const c10::cuda::OptionalCUDAGuard device_guard(prefix_output.device());
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS);
|
||||
if (output_scale.has_value()) {
|
||||
// FP8 output path - dispatch on output FP8 type
|
||||
VLLM_DISPATCH_FP8_TYPES(output.scalar_type(), "merge_attn_states_fp8", [&] {
|
||||
LAUNCH_MERGE_ATTN_STATES(scalar_t, fp8_t, NUM_THREADS, true);
|
||||
});
|
||||
} else {
|
||||
// Original BF16/FP16/FP32 output path
|
||||
LAUNCH_MERGE_ATTN_STATES(scalar_t, scalar_t, NUM_THREADS, false);
|
||||
}
|
||||
}
|
||||
|
||||
#define CALL_MERGE_ATTN_STATES_LAUNCHER(scalar_t) \
|
||||
{ \
|
||||
merge_attn_states_launcher<scalar_t>(output, output_lse, prefix_output, \
|
||||
prefix_lse, suffix_output, \
|
||||
suffix_lse); \
|
||||
#define CALL_MERGE_ATTN_STATES_LAUNCHER(scalar_t) \
|
||||
{ \
|
||||
merge_attn_states_launcher<scalar_t>( \
|
||||
output, output_lse, prefix_output, prefix_lse, suffix_output, \
|
||||
suffix_lse, prefill_tokens_with_context, output_scale); \
|
||||
}
|
||||
|
||||
void merge_attn_states(torch::Tensor& output,
|
||||
@@ -204,6 +310,21 @@ void merge_attn_states(torch::Tensor& output,
|
||||
const torch::Tensor& prefix_output,
|
||||
const torch::Tensor& prefix_lse,
|
||||
const torch::Tensor& suffix_output,
|
||||
const torch::Tensor& suffix_lse) {
|
||||
DISPATCH_BY_SCALAR_DTYPE(output.dtype(), CALL_MERGE_ATTN_STATES_LAUNCHER);
|
||||
const torch::Tensor& suffix_lse,
|
||||
std::optional<int64_t> prefill_tokens_with_context,
|
||||
const std::optional<torch::Tensor>& output_scale) {
|
||||
if (output_scale.has_value()) {
|
||||
TORCH_CHECK(output.scalar_type() == at::ScalarType::Float8_e4m3fn ||
|
||||
output.scalar_type() == at::ScalarType::Float8_e4m3fnuz,
|
||||
"output must be FP8 when output_scale is provided, got: ",
|
||||
output.scalar_type());
|
||||
} else {
|
||||
TORCH_CHECK(output.scalar_type() == prefix_output.scalar_type(),
|
||||
"output dtype (", output.scalar_type(),
|
||||
") must match prefix_output dtype (",
|
||||
prefix_output.scalar_type(), ") when output_scale is not set");
|
||||
}
|
||||
// Always dispatch on prefix_output (input) dtype
|
||||
DISPATCH_BY_SCALAR_DTYPE(prefix_output.dtype(),
|
||||
CALL_MERGE_ATTN_STATES_LAUNCHER);
|
||||
}
|
||||
|
||||
@@ -10,6 +10,10 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
|
||||
int64_t block_size_in_bytes,
|
||||
const torch::Tensor& block_mapping);
|
||||
|
||||
void swap_blocks_batch(const torch::Tensor& src_ptrs,
|
||||
const torch::Tensor& dst_ptrs,
|
||||
const torch::Tensor& sizes);
|
||||
|
||||
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
|
||||
torch::Tensor& key_cache, torch::Tensor& value_cache,
|
||||
torch::Tensor& slot_mapping,
|
||||
|
||||
@@ -24,6 +24,8 @@
|
||||
#ifdef USE_ROCM
|
||||
#include <hip/hip_bf16.h>
|
||||
typedef __hip_bfloat16 __nv_bfloat16;
|
||||
#else
|
||||
#include <cuda.h>
|
||||
#endif
|
||||
|
||||
#if defined(__gfx942__)
|
||||
@@ -73,6 +75,68 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
|
||||
}
|
||||
}
|
||||
|
||||
void swap_blocks_batch(const torch::Tensor& src_ptrs,
|
||||
const torch::Tensor& dst_ptrs,
|
||||
const torch::Tensor& sizes) {
|
||||
TORCH_CHECK(src_ptrs.device().is_cpu(), "src_ptrs must be on CPU");
|
||||
TORCH_CHECK(dst_ptrs.device().is_cpu(), "dst_ptrs must be on CPU");
|
||||
TORCH_CHECK(sizes.device().is_cpu(), "sizes must be on CPU");
|
||||
TORCH_CHECK(src_ptrs.dtype() == torch::kInt64, "src_ptrs must be int64");
|
||||
TORCH_CHECK(dst_ptrs.dtype() == torch::kInt64, "dst_ptrs must be int64");
|
||||
TORCH_CHECK(sizes.dtype() == torch::kInt64, "sizes must be int64");
|
||||
|
||||
const int64_t n = src_ptrs.size(0);
|
||||
TORCH_CHECK(dst_ptrs.size(0) == n, "dst_ptrs length must match src_ptrs");
|
||||
TORCH_CHECK(sizes.size(0) == n, "sizes length must match src_ptrs");
|
||||
|
||||
if (n == 0) return;
|
||||
|
||||
int64_t* src_data = src_ptrs.mutable_data_ptr<int64_t>();
|
||||
int64_t* dst_data = dst_ptrs.mutable_data_ptr<int64_t>();
|
||||
int64_t* size_data = sizes.mutable_data_ptr<int64_t>();
|
||||
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
// Use cuMemcpyBatchAsync (CUDA 12.8+) to submit all copies in a single
|
||||
// driver call, amortizing per-copy submission overhead.
|
||||
// int64_t and CUdeviceptr/size_t are both 8 bytes on 64-bit platforms,
|
||||
// so we reinterpret_cast the tensor data directly to avoid copies.
|
||||
static_assert(sizeof(CUdeviceptr) == sizeof(int64_t));
|
||||
static_assert(sizeof(size_t) == sizeof(int64_t));
|
||||
#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12080
|
||||
CUmemcpyAttributes attr = {};
|
||||
attr.srcAccessOrder = CU_MEMCPY_SRC_ACCESS_ORDER_STREAM;
|
||||
size_t attrs_idx = 0;
|
||||
#if defined(CUDA_VERSION) && CUDA_VERSION >= 13000
|
||||
CUresult result = cuMemcpyBatchAsync(
|
||||
reinterpret_cast<CUdeviceptr*>(dst_data),
|
||||
reinterpret_cast<CUdeviceptr*>(src_data),
|
||||
reinterpret_cast<size_t*>(size_data), static_cast<size_t>(n), &attr,
|
||||
&attrs_idx, 1, static_cast<CUstream>(stream));
|
||||
TORCH_CHECK(result == CUDA_SUCCESS, "cuMemcpyBatchAsync failed with error ",
|
||||
result);
|
||||
#else
|
||||
size_t fail_idx = 0;
|
||||
CUresult result = cuMemcpyBatchAsync(
|
||||
reinterpret_cast<CUdeviceptr*>(dst_data),
|
||||
reinterpret_cast<CUdeviceptr*>(src_data),
|
||||
reinterpret_cast<size_t*>(size_data), static_cast<size_t>(n), &attr,
|
||||
&attrs_idx, 1, &fail_idx, static_cast<CUstream>(stream));
|
||||
TORCH_CHECK(result == CUDA_SUCCESS, "cuMemcpyBatchAsync failed at index ",
|
||||
fail_idx, " with error ", result);
|
||||
#endif
|
||||
#else
|
||||
// Fallback for CUDA < 12.8 and ROCm: individual async copies.
|
||||
// cudaMemcpyDefault lets the driver infer direction from pointer types.
|
||||
for (int64_t i = 0; i < n; i++) {
|
||||
cudaMemcpyAsync(reinterpret_cast<void*>(dst_data[i]),
|
||||
reinterpret_cast<void*>(src_data[i]),
|
||||
static_cast<size_t>(size_data[i]), cudaMemcpyDefault,
|
||||
stream);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// Grid: (num_layers, num_pairs)
|
||||
|
||||
@@ -53,7 +53,7 @@ class TileGemm82 {
|
||||
const int64_t ldb, const int64_t ldc,
|
||||
const int32_t block_size, const int32_t dynamic_k_size,
|
||||
const bool accum_c) {
|
||||
static_assert(0 < M <= 8);
|
||||
static_assert(0 < M && M <= 8);
|
||||
using load_vec_t = typename VecTypeTrait<kv_cache_t>::vec_t;
|
||||
|
||||
kv_cache_t* __restrict__ curr_b_0 = b_tile;
|
||||
|
||||
@@ -68,7 +68,7 @@ class TileGemm161 {
|
||||
const int64_t ldb, const int64_t ldc,
|
||||
const int32_t block_size, const int32_t dynamic_k_size,
|
||||
const bool accum_c) {
|
||||
static_assert(0 < M <= 16);
|
||||
static_assert(0 < M && M <= 16);
|
||||
using load_vec_t = typename VecTypeTrait<kv_cache_t>::vec_t;
|
||||
|
||||
kv_cache_t* __restrict__ curr_b_0 = b_tile;
|
||||
|
||||
@@ -30,13 +30,15 @@
|
||||
}()
|
||||
|
||||
namespace {
|
||||
enum class FusedMOEAct { SiluAndMul, SwigluOAIAndMul };
|
||||
enum class FusedMOEAct { SiluAndMul, SwigluOAIAndMul, GeluAndMul };
|
||||
|
||||
FusedMOEAct get_act_type(const std::string& act) {
|
||||
if (act == "silu") {
|
||||
return FusedMOEAct::SiluAndMul;
|
||||
} else if (act == "swigluoai") {
|
||||
return FusedMOEAct::SwigluOAIAndMul;
|
||||
} else if (act == "gelu") {
|
||||
return FusedMOEAct::GeluAndMul;
|
||||
} else {
|
||||
TORCH_CHECK(false, "Invalid act type: " + act);
|
||||
}
|
||||
@@ -104,6 +106,43 @@ void silu_and_mul(float* __restrict__ input, scalar_t* __restrict__ output,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void gelu_and_mul(float* __restrict__ input, scalar_t* __restrict__ output,
|
||||
const int32_t m_size, const int32_t n_size,
|
||||
const int32_t input_stride, const int32_t output_stride) {
|
||||
using scalar_vec_t = typename cpu_utils::VecTypeTrait<scalar_t>::vec_t;
|
||||
const int32_t dim = n_size / 2;
|
||||
float* __restrict__ gate = input;
|
||||
float* __restrict__ up = input + dim;
|
||||
vec_op::FP32Vec16 one_vec(1.0);
|
||||
vec_op::FP32Vec16 w1_vec(M_SQRT1_2);
|
||||
vec_op::FP32Vec16 w2_vec(0.5);
|
||||
alignas(64) float temp[16];
|
||||
|
||||
DEFINE_FAST_EXP
|
||||
|
||||
for (int32_t m = 0; m < m_size; ++m) {
|
||||
for (int32_t n = 0; n < dim; n += 16) {
|
||||
vec_op::FP32Vec16 gate_vec(gate + n);
|
||||
vec_op::FP32Vec16 up_vec(up + n);
|
||||
auto er_input_vec = gate_vec * w1_vec;
|
||||
|
||||
er_input_vec.save(temp);
|
||||
for (int32_t i = 0; i < 16; ++i) {
|
||||
temp[i] = std::erf(temp[i]);
|
||||
}
|
||||
vec_op::FP32Vec16 er_vec(temp);
|
||||
auto gelu = gate_vec * w2_vec * (one_vec + er_vec);
|
||||
auto gated_output_fp32 = up_vec * gelu;
|
||||
scalar_vec_t gated_output = scalar_vec_t(gated_output_fp32);
|
||||
gated_output.save(output + n);
|
||||
}
|
||||
gate += input_stride;
|
||||
up += input_stride;
|
||||
output += output_stride;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
FORCE_INLINE void apply_gated_act(const FusedMOEAct act,
|
||||
float* __restrict__ input,
|
||||
@@ -118,6 +157,9 @@ FORCE_INLINE void apply_gated_act(const FusedMOEAct act,
|
||||
case FusedMOEAct::SiluAndMul:
|
||||
silu_and_mul(input, output, m, n, input_stride, output_stride);
|
||||
return;
|
||||
case FusedMOEAct::GeluAndMul:
|
||||
gelu_and_mul(input, output, m, n, input_stride, output_stride);
|
||||
return;
|
||||
default:
|
||||
TORCH_CHECK(false, "Unsupported act type.");
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ Generate CPU attention dispatch switch cases and kernel instantiations.
|
||||
import os
|
||||
|
||||
# Head dimensions divisible by 32 (support all ISAs)
|
||||
HEAD_DIMS_32 = [32, 64, 96, 128, 160, 192, 224, 256]
|
||||
HEAD_DIMS_32 = [32, 64, 96, 128, 160, 192, 224, 256, 512]
|
||||
|
||||
# Head dimensions divisible by 16 but not 32 (VEC16 only)
|
||||
HEAD_DIMS_16 = [80, 112]
|
||||
|
||||
@@ -39,7 +39,7 @@ class TileGemm82 {
|
||||
|
||||
template <int32_t M>
|
||||
static void gemm_micro(DEFINE_CPU_MICRO_GEMM_PARAMS) {
|
||||
static_assert(0 < M <= 8);
|
||||
static_assert(0 < M && M <= 8);
|
||||
using load_vec_t = typename cpu_utils::VecTypeTrait<scalar_t>::vec_t;
|
||||
|
||||
scalar_t* __restrict__ curr_b_0 = b_ptr;
|
||||
|
||||
@@ -8,8 +8,6 @@
|
||||
// libraries use different ISAs.
|
||||
#define TORCH_EXTENSION_NAME _C
|
||||
|
||||
std::string init_cpu_threads_env(const std::string& cpu_ids);
|
||||
|
||||
void release_dnnl_matmul_handler(int64_t handler);
|
||||
|
||||
int64_t create_onednn_scaled_mm_handler(const torch::Tensor& b,
|
||||
@@ -354,7 +352,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"str act, str isa) -> ()");
|
||||
ops.impl("cpu_fused_moe", torch::kCPU, &cpu_fused_moe);
|
||||
#endif
|
||||
ops.def("init_cpu_threads_env(str cpu_ids) -> str", &init_cpu_threads_env);
|
||||
ops.def(
|
||||
"mla_decode_kvcache("
|
||||
" Tensor! out, Tensor query, Tensor kv_cache,"
|
||||
|
||||
@@ -21,150 +21,6 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) {
|
||||
|
||||
#endif
|
||||
|
||||
#ifndef VLLM_NUMA_DISABLED
|
||||
std::string init_cpu_threads_env(const std::string& cpu_ids) {
|
||||
bitmask* omp_cpu_mask = numa_parse_cpustring_all(cpu_ids.c_str());
|
||||
TORCH_CHECK(omp_cpu_mask != nullptr,
|
||||
"Failed to parse CPU string: " + cpu_ids);
|
||||
TORCH_CHECK(omp_cpu_mask->size > 0);
|
||||
std::vector<int> omp_cpu_ids;
|
||||
omp_cpu_ids.reserve(omp_cpu_mask->size);
|
||||
|
||||
constexpr int group_size = 8 * sizeof(*omp_cpu_mask->maskp);
|
||||
|
||||
for (int offset = 0; offset < omp_cpu_mask->size; offset += group_size) {
|
||||
unsigned long group_mask = omp_cpu_mask->maskp[offset / group_size];
|
||||
int i = 0;
|
||||
while (group_mask) {
|
||||
if (group_mask & 1) {
|
||||
omp_cpu_ids.emplace_back(offset + i);
|
||||
}
|
||||
++i;
|
||||
group_mask >>= 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Memory node binding
|
||||
if (numa_available() != -1) {
|
||||
std::set<int> node_ids;
|
||||
for (const auto& cpu_id : omp_cpu_ids) {
|
||||
int node_id = numa_node_of_cpu(cpu_id);
|
||||
if (node_id != -1) {
|
||||
node_ids.insert(node_id);
|
||||
}
|
||||
}
|
||||
// Concatenate all node_ids into a single comma-separated string
|
||||
if (!node_ids.empty()) {
|
||||
std::string node_ids_str;
|
||||
for (const int node_id : node_ids) {
|
||||
if (!node_ids_str.empty()) {
|
||||
node_ids_str += ",";
|
||||
}
|
||||
node_ids_str += std::to_string(node_id);
|
||||
}
|
||||
|
||||
bitmask* mask = numa_parse_nodestring(node_ids_str.c_str());
|
||||
bitmask* src_mask = numa_get_mems_allowed();
|
||||
|
||||
int pid = getpid();
|
||||
|
||||
if (mask && src_mask) {
|
||||
// move all existing pages to the specified numa node.
|
||||
*(src_mask->maskp) = *(src_mask->maskp) ^ *(mask->maskp);
|
||||
int page_num = numa_migrate_pages(pid, src_mask, mask);
|
||||
if (page_num == -1) {
|
||||
TORCH_WARN("numa_migrate_pages failed. errno: " +
|
||||
std::to_string(errno));
|
||||
}
|
||||
|
||||
// Restrict memory allocation to the selected NUMA node(s).
|
||||
// Enhances memory locality for the threads bound to those NUMA CPUs.
|
||||
if (node_ids.size() > 1) {
|
||||
errno = 0;
|
||||
numa_set_interleave_mask(mask);
|
||||
if (errno != 0) {
|
||||
TORCH_WARN("numa_set_interleave_mask failed. errno: " +
|
||||
std::to_string(errno));
|
||||
} else {
|
||||
TORCH_WARN(
|
||||
"NUMA binding: Using INTERLEAVE policy for memory "
|
||||
"allocation across multiple NUMA nodes (nodes: " +
|
||||
node_ids_str +
|
||||
"). Memory allocations will be "
|
||||
"interleaved across the specified NUMA nodes.");
|
||||
}
|
||||
} else {
|
||||
errno = 0;
|
||||
numa_set_membind(mask);
|
||||
if (errno != 0) {
|
||||
TORCH_WARN("numa_set_membind failed. errno: " +
|
||||
std::to_string(errno));
|
||||
} else {
|
||||
TORCH_WARN(
|
||||
"NUMA binding: Using MEMBIND policy for memory "
|
||||
"allocation on the NUMA nodes (" +
|
||||
node_ids_str +
|
||||
"). Memory allocations will be "
|
||||
"strictly bound to these NUMA nodes.");
|
||||
}
|
||||
}
|
||||
|
||||
numa_set_strict(1);
|
||||
|
||||
numa_free_nodemask(mask);
|
||||
numa_free_nodemask(src_mask);
|
||||
} else {
|
||||
TORCH_WARN(
|
||||
"numa_parse_nodestring or numa_get_run_node_mask failed. errno: " +
|
||||
std::to_string(errno));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OMP threads binding
|
||||
omp_set_num_threads((int)omp_cpu_ids.size());
|
||||
torch::set_num_threads((int)omp_cpu_ids.size());
|
||||
TORCH_CHECK_EQ(omp_cpu_ids.size(), torch::get_num_threads());
|
||||
TORCH_CHECK_EQ(omp_cpu_ids.size(), omp_get_max_threads());
|
||||
|
||||
std::vector<std::pair<int, int>> thread_core_mapping;
|
||||
thread_core_mapping.reserve(omp_cpu_ids.size());
|
||||
omp_lock_t writelock;
|
||||
omp_init_lock(&writelock);
|
||||
|
||||
#pragma omp parallel for schedule(static, 1)
|
||||
for (size_t i = 0; i < omp_cpu_ids.size(); ++i) {
|
||||
cpu_set_t mask;
|
||||
CPU_ZERO(&mask);
|
||||
CPU_SET(omp_cpu_ids[i], &mask);
|
||||
int ret = sched_setaffinity(0, sizeof(cpu_set_t), &mask);
|
||||
if (ret == -1) {
|
||||
TORCH_CHECK(false,
|
||||
"sched_setaffinity failed. errno: " + std::to_string(errno));
|
||||
}
|
||||
|
||||
omp_set_lock(&writelock);
|
||||
thread_core_mapping.emplace_back(gettid(), omp_cpu_ids[i]);
|
||||
omp_unset_lock(&writelock);
|
||||
}
|
||||
|
||||
omp_destroy_lock(&writelock);
|
||||
|
||||
numa_free_nodemask(omp_cpu_mask);
|
||||
|
||||
std::stringstream ss;
|
||||
ss << "OMP threads binding of Process " << getpid() << ":\n";
|
||||
std::sort(thread_core_mapping.begin(), thread_core_mapping.end(),
|
||||
[](auto&& a, auto&& b) { return a.second < b.second; });
|
||||
for (auto&& item : thread_core_mapping) {
|
||||
ss << "\t"
|
||||
<< "OMP tid: " << item.first << ", core " << item.second << "\n";
|
||||
}
|
||||
|
||||
return ss.str();
|
||||
}
|
||||
#endif // VLLM_NUMA_DISABLED
|
||||
|
||||
namespace cpu_utils {
|
||||
ScratchPadManager::ScratchPadManager() : size_(0), ptr_(nullptr) {
|
||||
this->realloc(allocation_unit * 128);
|
||||
|
||||
@@ -55,7 +55,8 @@ struct Counter {
|
||||
|
||||
inline int64_t get_available_l2_size() {
|
||||
static int64_t size = []() {
|
||||
const uint32_t l2_cache_size = at::cpu::L2_cache_size();
|
||||
auto caps = at::cpu::get_cpu_capabilities();
|
||||
const uint32_t l2_cache_size = caps.at("l2_cache_size").toInt();
|
||||
return l2_cache_size >> 1; // use 50% of L2 cache
|
||||
}();
|
||||
return size;
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <c10/util/BFloat16.h>
|
||||
#include <c10/util/Half.h>
|
||||
#include <torch/headeronly/util/BFloat16.h>
|
||||
#include <torch/headeronly/util/Half.h>
|
||||
#include <cassert>
|
||||
|
||||
#ifdef USE_ROCM
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
#include <torch/all.h>
|
||||
namespace cute {
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
|
||||
@@ -189,9 +189,9 @@ struct Sm90RowOrScalarBroadcastArray {
|
||||
}
|
||||
|
||||
auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
|
||||
Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
|
||||
Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
|
||||
Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
|
||||
cute::Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
|
||||
cute::Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
|
||||
cute::Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
|
||||
|
||||
for (int i = 0; i < size(tGS_gRow_flt); ++i) {
|
||||
if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
|
||||
@@ -211,8 +211,8 @@ struct Sm90RowOrScalarBroadcastArray {
|
||||
begin_loop(int epi_m, int epi_n) {
|
||||
if (epi_m == 0) { // Assumes M-major subtile loop
|
||||
if (!params.row_broadcast) return; // Do not issue LDS when row is scalar
|
||||
Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
|
||||
Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
|
||||
cute::Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
|
||||
cute::Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
|
||||
copy(tSR_sRow_flt, tSR_rRow_flt);
|
||||
}
|
||||
}
|
||||
@@ -241,9 +241,9 @@ struct Sm90RowOrScalarBroadcastArray {
|
||||
auto [m, n, k, l] = args.tile_coord_mnkl;
|
||||
using ThreadCount = decltype(size(args.tiled_copy));
|
||||
|
||||
Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row_array[l]), make_shape(M,N,1), params.dRow);
|
||||
Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
|
||||
Tensor sRow = make_tensor(make_smem_ptr(smem),
|
||||
cute::Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row_array[l]), make_shape(M,N,1), params.dRow);
|
||||
cute::Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
|
||||
cute::Tensor sRow = make_tensor(make_smem_ptr(smem),
|
||||
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N)
|
||||
//// G2S: Gmem to Smem
|
||||
auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
|
||||
@@ -251,16 +251,16 @@ struct Sm90RowOrScalarBroadcastArray {
|
||||
Stride<_0, _1>>{},
|
||||
Layout<_1>{});
|
||||
auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
|
||||
Tensor tGS_gRow = thr_g2s.partition_S(gRow);
|
||||
Tensor tGS_sRow = thr_g2s.partition_D(sRow);
|
||||
cute::Tensor tGS_gRow = thr_g2s.partition_S(gRow);
|
||||
cute::Tensor tGS_sRow = thr_g2s.partition_D(sRow);
|
||||
|
||||
//// G2S: Coord
|
||||
auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
|
||||
Tensor tGS_cRow = thr_g2s.partition_S(cRow);
|
||||
cute::Tensor tGS_cRow = thr_g2s.partition_S(cRow);
|
||||
|
||||
//// S2R: Smem to Reg
|
||||
Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||
Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
|
||||
cute::Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||
cute::Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
|
||||
|
||||
return ConsumerStoreCallbacks<decltype(tGS_gRow), decltype(tGS_sRow), decltype(tGS_cRow), decltype(tiled_g2s), decltype(tSR_sRow), decltype(tSR_rRow), decltype(args.tCcD), decltype(args.residue_cD), ThreadCount>(
|
||||
tGS_gRow,
|
||||
@@ -389,7 +389,7 @@ struct Sm90ColOrScalarBroadcastArray {
|
||||
|
||||
CUTLASS_DEVICE void
|
||||
begin() {
|
||||
Tensor pred = make_tensor<bool>(shape(tCgCol));
|
||||
cute::Tensor pred = make_tensor<bool>(shape(tCgCol));
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(pred); ++i) {
|
||||
pred(i) = get<0>(tCcCol(i)) < m;
|
||||
@@ -409,7 +409,7 @@ struct Sm90ColOrScalarBroadcastArray {
|
||||
CUTLASS_DEVICE Array<Element, FragmentSize>
|
||||
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
|
||||
Array<Element, FragmentSize> frg_col;
|
||||
Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
|
||||
cute::Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < FragmentSize; ++i) {
|
||||
@@ -431,16 +431,16 @@ struct Sm90ColOrScalarBroadcastArray {
|
||||
auto [M, N, K, L] = args.problem_shape_mnkl;
|
||||
auto [m, n, k, l] = args.tile_coord_mnkl;
|
||||
|
||||
Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col_array[l]), make_shape(M,N,1), params.dCol);
|
||||
Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
cute::Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col_array[l]), make_shape(M,N,1), params.dCol);
|
||||
cute::Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||
Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
cute::Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
|
||||
// Generate an identity tensor matching the shape of the global tensor and
|
||||
// partition the same way, this will be used to generate the predicate
|
||||
// tensor for loading
|
||||
Tensor cCol = make_identity_tensor(mCol.shape());
|
||||
Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
cute::Tensor cCol = make_identity_tensor(mCol.shape());
|
||||
cute::Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||
|
||||
return ConsumerStoreCallbacks(
|
||||
|
||||
@@ -186,9 +186,9 @@ struct Sm90RowOrScalarBroadcast {
|
||||
}
|
||||
|
||||
auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
|
||||
Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
|
||||
Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
|
||||
Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
|
||||
cute::Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
|
||||
cute::Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
|
||||
cute::Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
|
||||
|
||||
for (int i = 0; i < size(tGS_gRow_flt); ++i) {
|
||||
if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
|
||||
@@ -208,8 +208,8 @@ struct Sm90RowOrScalarBroadcast {
|
||||
begin_loop(int epi_m, int epi_n) {
|
||||
if (epi_m == 0) { // Assumes M-major subtile loop
|
||||
if (!params.row_broadcast) return; // Do not issue LDS when row is scalar
|
||||
Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
|
||||
Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
|
||||
cute::Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
|
||||
cute::Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
|
||||
copy(tSR_sRow_flt, tSR_rRow_flt);
|
||||
}
|
||||
}
|
||||
@@ -238,9 +238,9 @@ struct Sm90RowOrScalarBroadcast {
|
||||
auto [m, n, k, l] = args.tile_coord_mnkl;
|
||||
using ThreadCount = decltype(size(args.tiled_copy));
|
||||
|
||||
Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
|
||||
Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
|
||||
Tensor sRow = make_tensor(make_smem_ptr(smem),
|
||||
cute::Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
|
||||
cute::Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
|
||||
cute::Tensor sRow = make_tensor(make_smem_ptr(smem),
|
||||
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N)
|
||||
//// G2S: Gmem to Smem
|
||||
auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
|
||||
@@ -248,16 +248,16 @@ struct Sm90RowOrScalarBroadcast {
|
||||
Stride<_0, _1>>{},
|
||||
Layout<_1>{});
|
||||
auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
|
||||
Tensor tGS_gRow = thr_g2s.partition_S(gRow);
|
||||
Tensor tGS_sRow = thr_g2s.partition_D(sRow);
|
||||
cute::Tensor tGS_gRow = thr_g2s.partition_S(gRow);
|
||||
cute::Tensor tGS_sRow = thr_g2s.partition_D(sRow);
|
||||
|
||||
//// G2S: Coord
|
||||
auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
|
||||
Tensor tGS_cRow = thr_g2s.partition_S(cRow);
|
||||
cute::Tensor tGS_cRow = thr_g2s.partition_S(cRow);
|
||||
|
||||
//// S2R: Smem to Reg
|
||||
Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||
Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
|
||||
cute::Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||
cute::Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
|
||||
|
||||
return ConsumerStoreCallbacks<decltype(tGS_gRow), decltype(tGS_sRow), decltype(tGS_cRow), decltype(tiled_g2s), decltype(tSR_sRow), decltype(tSR_rRow), decltype(args.tCcD), decltype(args.residue_cD), ThreadCount>(
|
||||
tGS_gRow,
|
||||
@@ -382,7 +382,7 @@ struct Sm90ColOrScalarBroadcast {
|
||||
|
||||
CUTLASS_DEVICE void
|
||||
begin() {
|
||||
Tensor pred = make_tensor<bool>(shape(tCgCol));
|
||||
cute::Tensor pred = make_tensor<bool>(shape(tCgCol));
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(pred); ++i) {
|
||||
pred(i) = get<0>(tCcCol(i)) < m;
|
||||
@@ -402,7 +402,7 @@ struct Sm90ColOrScalarBroadcast {
|
||||
CUTLASS_DEVICE Array<Element, FragmentSize>
|
||||
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
|
||||
Array<Element, FragmentSize> frg_col;
|
||||
Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
|
||||
cute::Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < FragmentSize; ++i) {
|
||||
@@ -422,16 +422,16 @@ struct Sm90ColOrScalarBroadcast {
|
||||
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
|
||||
|
||||
auto [M, N, K, L] = args.problem_shape_mnkl;
|
||||
Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol);
|
||||
Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
cute::Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol);
|
||||
cute::Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||
Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
cute::Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
|
||||
// Generate an identity tensor matching the shape of the global tensor and
|
||||
// partition the same way, this will be used to generate the predicate
|
||||
// tensor for loading
|
||||
Tensor cCol = make_identity_tensor(mCol.shape());
|
||||
Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
cute::Tensor cCol = make_identity_tensor(mCol.shape());
|
||||
cute::Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||
|
||||
return ConsumerStoreCallbacks(
|
||||
|
||||
@@ -1,6 +1,21 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/all.h>
|
||||
// This header is shared between _C (unstable ABI, used by machete) and
|
||||
// _C_stable_libtorch (stable ABI, used by W4A8/sparse). TORCH_TARGET_VERSION
|
||||
// is defined only for the stable target, so we switch includes and types
|
||||
// accordingly. TorchTensor (not Tensor) avoids ambiguity with cute::Tensor.
|
||||
#ifdef TORCH_TARGET_VERSION
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include <torch/headeronly/util/BFloat16.h>
|
||||
#include <torch/headeronly/util/Half.h>
|
||||
#include <torch/headeronly/util/shim_utils.h> // for STD_TORCH_CHECK
|
||||
using TorchTensor = torch::stable::Tensor;
|
||||
#define TORCH_UTILS_CHECK STD_TORCH_CHECK
|
||||
#else
|
||||
#include <torch/all.h>
|
||||
using TorchTensor = torch::Tensor;
|
||||
#define TORCH_UTILS_CHECK TORCH_CHECK
|
||||
#endif
|
||||
|
||||
#include "cute/layout.hpp"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
@@ -55,35 +70,35 @@ CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f) {
|
||||
// If `tensor.dim() < rank(Stride{})`, the shape is padded with 1s and the extra
|
||||
// strides are set to be 0 or 1.
|
||||
template <typename Stride>
|
||||
static inline auto make_cute_layout(torch::Tensor const& tensor,
|
||||
static inline auto make_cute_layout(TorchTensor const& tensor,
|
||||
std::string_view name = "tensor") {
|
||||
TORCH_CHECK(tensor.dim() <= rank(Stride{}));
|
||||
auto stride = cute::transform_with_idx(
|
||||
Stride{}, [&](auto const& stride_ele, auto const& idx) {
|
||||
using StrideEle = std::decay_t<decltype(stride_ele)>;
|
||||
TORCH_UTILS_CHECK(tensor.dim() <= rank(Stride{}));
|
||||
auto stride = cute::transform_with_idx(Stride{}, [&](auto const& stride_ele,
|
||||
auto const& idx) {
|
||||
using StrideEle = std::decay_t<decltype(stride_ele)>;
|
||||
|
||||
if (idx < tensor.dim()) {
|
||||
if constexpr (cute::is_static_v<StrideEle>) {
|
||||
TORCH_CHECK(StrideEle::value == tensor.stride(idx), "Expected ",
|
||||
name, ".stride(", idx, ") to be ", StrideEle::value);
|
||||
return StrideEle{};
|
||||
} else {
|
||||
if (tensor.size(idx) == 1) {
|
||||
// use 0 stride for dim with size 1, this is easier for
|
||||
// cute/cutlass to optimize (helps the TMA code flatten dims)
|
||||
return StrideEle{0};
|
||||
} else {
|
||||
return tensor.stride(idx);
|
||||
}
|
||||
}
|
||||
if (idx < tensor.dim()) {
|
||||
if constexpr (cute::is_static_v<StrideEle>) {
|
||||
TORCH_UTILS_CHECK(StrideEle::value == tensor.stride(idx), "Expected ",
|
||||
name, ".stride(", idx, ") to be ", StrideEle::value);
|
||||
return StrideEle{};
|
||||
} else {
|
||||
if (tensor.size(idx) == 1) {
|
||||
// use 0 stride for dim with size 1, this is easier for
|
||||
// cute/cutlass to optimize (helps the TMA code flatten dims)
|
||||
return StrideEle{0};
|
||||
} else {
|
||||
// Extra strides are assumed to be 0 or 1
|
||||
if constexpr (cute::is_static_v<StrideEle>) {
|
||||
static_assert(StrideEle::value == 0 || StrideEle::value == 1);
|
||||
}
|
||||
return StrideEle{};
|
||||
return tensor.stride(idx);
|
||||
}
|
||||
});
|
||||
}
|
||||
} else {
|
||||
// Extra strides are assumed to be 0 or 1
|
||||
if constexpr (cute::is_static_v<StrideEle>) {
|
||||
static_assert(StrideEle::value == 0 || StrideEle::value == 1);
|
||||
}
|
||||
return StrideEle{};
|
||||
}
|
||||
});
|
||||
|
||||
auto shape = cute::make_shape_from_idx<rank(Stride{})>([&](auto const& idx) {
|
||||
if (idx < tensor.dim())
|
||||
@@ -97,7 +112,7 @@ static inline auto make_cute_layout(torch::Tensor const& tensor,
|
||||
|
||||
template <typename Stride>
|
||||
static inline auto maybe_make_cute_layout(
|
||||
std::optional<torch::Tensor> const& tensor,
|
||||
std::optional<TorchTensor> const& tensor,
|
||||
std::string_view name = "tensor") {
|
||||
using Layout = decltype(make_cute_layout<Stride>(*tensor));
|
||||
|
||||
@@ -121,12 +136,12 @@ template <typename T>
|
||||
using equivalent_cutlass_type_t = typename equivalent_cutlass_type<T>::type;
|
||||
|
||||
template <>
|
||||
struct equivalent_cutlass_type<c10::Half> {
|
||||
struct equivalent_cutlass_type<torch::headeronly::Half> {
|
||||
using type = cutlass::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct equivalent_cutlass_type<c10::BFloat16> {
|
||||
struct equivalent_cutlass_type<torch::headeronly::BFloat16> {
|
||||
using type = cutlass::bfloat16_t;
|
||||
};
|
||||
|
||||
@@ -134,8 +149,8 @@ struct equivalent_cutlass_type<c10::BFloat16> {
|
||||
// equivalent_scalar_t (basically inverse of equivalent_cutlass_type)
|
||||
//
|
||||
|
||||
// Return a `c10::CppTypeToScalarType<T>` compatible type, i.e. get the C++ from
|
||||
// c10 that is equivalent to T, e.g.: `cutlass::half_t -> c10::Half`
|
||||
// Return a `torch::headeronly::CppTypeToScalarType<T>` compatible type, i.e.
|
||||
// get the C++ type equivalent to T, e.g.: `cutlass::half_t -> Half`
|
||||
template <typename T>
|
||||
struct equivalent_scalar_type {
|
||||
using type = T;
|
||||
@@ -146,15 +161,15 @@ using equivalent_scalar_type_t = typename equivalent_scalar_type<T>::type;
|
||||
|
||||
template <>
|
||||
struct equivalent_scalar_type<cutlass::half_t> {
|
||||
using type = c10::Half;
|
||||
using type = torch::headeronly::Half;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct equivalent_scalar_type<cutlass::bfloat16_t> {
|
||||
using type = c10::BFloat16;
|
||||
using type = torch::headeronly::BFloat16;
|
||||
};
|
||||
|
||||
// get equivalent c10::ScalarType tag from compile time type
|
||||
// get equivalent torch::headeronly::ScalarType tag from compile time type
|
||||
template <typename T>
|
||||
static inline constexpr c10::ScalarType equivalent_scalar_type_v =
|
||||
c10::CppTypeToScalarType<equivalent_scalar_type_t<T>>::value;
|
||||
static inline constexpr torch::headeronly::ScalarType equivalent_scalar_type_v =
|
||||
torch::headeronly::CppTypeToScalarType<equivalent_scalar_type_t<T>>::value;
|
||||
|
||||
@@ -49,6 +49,15 @@
|
||||
THO_DISPATCH_SWITCH(TYPE, NAME, \
|
||||
VLLM_STABLE_DISPATCH_CASE_FP8_TYPES(__VA_ARGS__))
|
||||
|
||||
// Half types dispatch (Half + BFloat16)
|
||||
#define VLLM_STABLE_DISPATCH_CASE_HALF_TYPES(...) \
|
||||
THO_DISPATCH_CASE(torch::headeronly::ScalarType::Half, __VA_ARGS__) \
|
||||
THO_DISPATCH_CASE(torch::headeronly::ScalarType::BFloat16, __VA_ARGS__)
|
||||
|
||||
#define VLLM_STABLE_DISPATCH_HALF_TYPES(TYPE, NAME, ...) \
|
||||
THO_DISPATCH_SWITCH(TYPE, NAME, \
|
||||
VLLM_STABLE_DISPATCH_CASE_HALF_TYPES(__VA_ARGS__))
|
||||
|
||||
// Boolean dispatch
|
||||
#define VLLM_STABLE_DISPATCH_BOOL(expr, const_expr, ...) \
|
||||
if (expr) { \
|
||||
|
||||
@@ -84,4 +84,54 @@ void get_cutlass_batched_moe_mm_data(
|
||||
const torch::stable::Tensor& expert_num_tokens,
|
||||
const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
|
||||
const int64_t k);
|
||||
|
||||
// FP4/NVFP4 ops
|
||||
bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability);
|
||||
|
||||
void cutlass_scaled_fp4_mm(torch::stable::Tensor& D,
|
||||
torch::stable::Tensor const& A,
|
||||
torch::stable::Tensor const& B,
|
||||
torch::stable::Tensor const& A_sf,
|
||||
torch::stable::Tensor const& B_sf,
|
||||
torch::stable::Tensor const& alpha);
|
||||
|
||||
void cutlass_fp4_group_mm(torch::stable::Tensor& output,
|
||||
const torch::stable::Tensor& a,
|
||||
const torch::stable::Tensor& b,
|
||||
const torch::stable::Tensor& a_blockscale,
|
||||
const torch::stable::Tensor& b_blockscales,
|
||||
const torch::stable::Tensor& alphas,
|
||||
const torch::stable::Tensor& problem_sizes,
|
||||
const torch::stable::Tensor& expert_offsets,
|
||||
const torch::stable::Tensor& sf_offsets);
|
||||
|
||||
std::tuple<torch::stable::Tensor, torch::stable::Tensor> scaled_fp4_quant_func(
|
||||
torch::stable::Tensor const& input,
|
||||
torch::stable::Tensor const& input_scale, bool is_sf_swizzled_layout);
|
||||
|
||||
void scaled_fp4_quant_out(torch::stable::Tensor const& input,
|
||||
torch::stable::Tensor const& input_scale,
|
||||
bool is_sf_swizzled_layout,
|
||||
torch::stable::Tensor& output,
|
||||
torch::stable::Tensor& output_scale);
|
||||
|
||||
void scaled_fp4_experts_quant(
|
||||
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
|
||||
torch::stable::Tensor const& input,
|
||||
torch::stable::Tensor const& input_global_scale,
|
||||
torch::stable::Tensor const& input_offset_by_experts,
|
||||
torch::stable::Tensor const& output_scale_offset_by_experts);
|
||||
|
||||
void silu_and_mul_scaled_fp4_experts_quant(
|
||||
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
|
||||
torch::stable::Tensor const& input,
|
||||
torch::stable::Tensor const& input_global_scale,
|
||||
torch::stable::Tensor const& input_offset_by_experts,
|
||||
torch::stable::Tensor const& output_scale_offset_by_experts);
|
||||
|
||||
void silu_and_mul_nvfp4_quant(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor& output_block_scale,
|
||||
torch::stable::Tensor& input,
|
||||
torch::stable::Tensor& input_global_scale);
|
||||
|
||||
#endif
|
||||
|
||||
@@ -2,10 +2,9 @@
|
||||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
#include <torch/all.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
|
||||
#include "core/scalar_type.hpp"
|
||||
#include "cutlass/bfloat16.h"
|
||||
#include "cutlass/float8.h"
|
||||
|
||||
@@ -41,7 +40,7 @@ __global__ void get_group_gemm_starts(
|
||||
}
|
||||
|
||||
#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE) \
|
||||
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
|
||||
else if (out_tensors.scalar_type() == TENSOR_C_TYPE) { \
|
||||
get_group_gemm_starts<cutlass::float_e4m3_t, int32_t, C_TYPE, float, \
|
||||
cutlass::Array<cutlass::float_e4m3_t, 8>> \
|
||||
<<<1, num_experts, 0, stream>>>( \
|
||||
@@ -66,23 +65,34 @@ __global__ void get_group_gemm_starts(
|
||||
namespace {
|
||||
|
||||
void run_get_group_gemm_starts(
|
||||
torch::Tensor const& expert_offsets, torch::Tensor& a_ptrs,
|
||||
torch::Tensor& b_ptrs, torch::Tensor& out_ptrs,
|
||||
torch::Tensor& a_scales_ptrs, torch::Tensor& b_scales_ptrs,
|
||||
torch::Tensor& b_group_scales_ptrs, torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors, torch::Tensor& out_tensors,
|
||||
torch::Tensor const& a_scales, torch::Tensor const& b_scales,
|
||||
torch::Tensor const& b_group_scales, const int64_t b_group_size) {
|
||||
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b_tensors.dtype() == torch::kInt32); // int4 8x packed into int32
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_group_scales.dtype() ==
|
||||
torch::kFloat8_e4m3fn); // the underlying torch type is e4m3
|
||||
TORCH_CHECK(out_tensors.dtype() ==
|
||||
torch::kBFloat16); // only support bf16 for now
|
||||
torch::stable::Tensor const& expert_offsets, torch::stable::Tensor& a_ptrs,
|
||||
torch::stable::Tensor& b_ptrs, torch::stable::Tensor& out_ptrs,
|
||||
torch::stable::Tensor& a_scales_ptrs, torch::stable::Tensor& b_scales_ptrs,
|
||||
torch::stable::Tensor& b_group_scales_ptrs,
|
||||
torch::stable::Tensor const& a_tensors,
|
||||
torch::stable::Tensor const& b_tensors, torch::stable::Tensor& out_tensors,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
torch::stable::Tensor const& b_group_scales, const int64_t b_group_size) {
|
||||
STD_TORCH_CHECK(a_tensors.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
STD_TORCH_CHECK(
|
||||
b_tensors.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Int); // int4 8x packed into int32
|
||||
STD_TORCH_CHECK(a_scales.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float);
|
||||
STD_TORCH_CHECK(b_scales.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float);
|
||||
STD_TORCH_CHECK(
|
||||
b_group_scales.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn); // the underlying torch
|
||||
// type is e4m3
|
||||
STD_TORCH_CHECK(
|
||||
out_tensors.scalar_type() ==
|
||||
torch::headeronly::ScalarType::BFloat16); // only support bf16 for now
|
||||
// expect int64_t to avoid overflow during offset calculations
|
||||
TORCH_CHECK(expert_offsets.dtype() == torch::kInt64);
|
||||
STD_TORCH_CHECK(expert_offsets.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Long);
|
||||
|
||||
int num_experts = static_cast<int>(expert_offsets.size(0));
|
||||
// logical k, n
|
||||
@@ -90,15 +100,16 @@ void run_get_group_gemm_starts(
|
||||
int64_t k = a_tensors.size(1);
|
||||
int64_t scale_k = cutlass::ceil_div(k, b_group_size);
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
|
||||
auto stream = get_current_cuda_stream(a_tensors.get_device_index());
|
||||
|
||||
if (false) {
|
||||
}
|
||||
__CALL_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t)
|
||||
__CALL_GET_STARTS_KERNEL(torch::kFloat16, half)
|
||||
__CALL_GET_STARTS_KERNEL(torch::headeronly::ScalarType::BFloat16,
|
||||
cutlass::bfloat16_t)
|
||||
__CALL_GET_STARTS_KERNEL(torch::headeronly::ScalarType::Half, half)
|
||||
else {
|
||||
TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
|
||||
STD_TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace
|
||||
@@ -14,13 +14,12 @@
|
||||
#include "cutlass/util/mixed_dtype_utils.hpp"
|
||||
|
||||
// vllm includes
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/all.h>
|
||||
#include <torch/csrc/stable/library.h>
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
#include "cutlass_extensions/torch_utils.hpp"
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
|
||||
#include "core/registration.h"
|
||||
#include "get_group_starts.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
#include "w4a8_utils.cuh"
|
||||
@@ -168,31 +167,40 @@ struct W4A8GroupedGemmKernel {
|
||||
static_assert(sizeof(LayoutB_Reordered) % sizeof(int32_t) == 0,
|
||||
"LayoutB_Reordered size must be divisible by 4 bytes");
|
||||
|
||||
static void grouped_mm(
|
||||
torch::Tensor& out_tensors, const torch::Tensor& a_tensors,
|
||||
const torch::Tensor& b_tensors, const torch::Tensor& a_scales,
|
||||
const torch::Tensor& b_scales, const torch::Tensor& b_group_scales,
|
||||
const int64_t b_group_size, const torch::Tensor& expert_offsets,
|
||||
const torch::Tensor& problem_sizes_torch, const torch::Tensor& a_strides,
|
||||
const torch::Tensor& b_strides, const torch::Tensor& c_strides,
|
||||
const torch::Tensor& group_scale_strides) {
|
||||
static void grouped_mm(torch::stable::Tensor& out_tensors,
|
||||
const torch::stable::Tensor& a_tensors,
|
||||
const torch::stable::Tensor& b_tensors,
|
||||
const torch::stable::Tensor& a_scales,
|
||||
const torch::stable::Tensor& b_scales,
|
||||
const torch::stable::Tensor& b_group_scales,
|
||||
const int64_t b_group_size,
|
||||
const torch::stable::Tensor& expert_offsets,
|
||||
const torch::stable::Tensor& problem_sizes_torch,
|
||||
const torch::stable::Tensor& a_strides,
|
||||
const torch::stable::Tensor& b_strides,
|
||||
const torch::stable::Tensor& c_strides,
|
||||
const torch::stable::Tensor& group_scale_strides) {
|
||||
auto device = a_tensors.device();
|
||||
auto device_id = device.index();
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device);
|
||||
auto stream = at::cuda::getCurrentCUDAStream(device_id);
|
||||
const torch::stable::accelerator::DeviceGuard device_guard(device_id);
|
||||
auto stream = get_current_cuda_stream(device_id);
|
||||
|
||||
int num_experts = static_cast<int>(expert_offsets.size(0));
|
||||
int n = static_cast<int>(b_tensors.size(1));
|
||||
int k = static_cast<int>(b_tensors.size(2)) * PackFactor;
|
||||
|
||||
auto options_int =
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(device);
|
||||
torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor b_group_scales_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::stable::Tensor a_ptrs = torch::stable::empty(
|
||||
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
|
||||
torch::stable::Tensor b_ptrs = torch::stable::empty(
|
||||
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
|
||||
torch::stable::Tensor out_ptrs = torch::stable::empty(
|
||||
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
|
||||
torch::stable::Tensor a_scales_ptrs = torch::stable::empty(
|
||||
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
|
||||
torch::stable::Tensor b_scales_ptrs = torch::stable::empty(
|
||||
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
|
||||
torch::stable::Tensor b_group_scales_ptrs = torch::stable::empty(
|
||||
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
|
||||
|
||||
// get the correct offsets to pass to gemm
|
||||
run_get_group_gemm_starts(expert_offsets, a_ptrs, b_ptrs, out_ptrs,
|
||||
@@ -247,9 +255,9 @@ struct W4A8GroupedGemmKernel {
|
||||
|
||||
// Allocate workspace
|
||||
size_t workspace_size = GemmShuffled::get_workspace_size(arguments);
|
||||
torch::Tensor workspace =
|
||||
torch::empty(workspace_size,
|
||||
torch::TensorOptions().dtype(torch::kU8).device(device));
|
||||
torch::stable::Tensor workspace = torch::stable::empty(
|
||||
workspace_size, torch::headeronly::ScalarType::Byte, std::nullopt,
|
||||
device);
|
||||
|
||||
// Run GEMM
|
||||
GemmShuffled gemm;
|
||||
@@ -294,14 +302,20 @@ using Kernel_256x128_2x1x1_Coop =
|
||||
using Kernel_128x256_2x1x1_Coop =
|
||||
W4A8GroupedGemmKernel<Shape<_128, _256>, Shape<_2, _1, _1>, Coop, CoopEpi>;
|
||||
|
||||
void mm_dispatch(
|
||||
torch::Tensor& out_tensors, const torch::Tensor& a_tensors,
|
||||
const torch::Tensor& b_tensors, const torch::Tensor& a_scales,
|
||||
const torch::Tensor& b_scales, const torch::Tensor& b_group_scales,
|
||||
const int64_t b_group_size, const torch::Tensor& expert_offsets,
|
||||
const torch::Tensor& problem_sizes, const torch::Tensor& a_strides,
|
||||
const torch::Tensor& b_strides, const torch::Tensor& c_strides,
|
||||
const torch::Tensor& group_scale_strides, const std::string& schedule) {
|
||||
void mm_dispatch(torch::stable::Tensor& out_tensors,
|
||||
const torch::stable::Tensor& a_tensors,
|
||||
const torch::stable::Tensor& b_tensors,
|
||||
const torch::stable::Tensor& a_scales,
|
||||
const torch::stable::Tensor& b_scales,
|
||||
const torch::stable::Tensor& b_group_scales,
|
||||
const int64_t b_group_size,
|
||||
const torch::stable::Tensor& expert_offsets,
|
||||
const torch::stable::Tensor& problem_sizes,
|
||||
const torch::stable::Tensor& a_strides,
|
||||
const torch::stable::Tensor& b_strides,
|
||||
const torch::stable::Tensor& c_strides,
|
||||
const torch::stable::Tensor& group_scale_strides,
|
||||
const std::string& schedule) {
|
||||
if (schedule == "Kernel_128x16_1x1x1_Coop") {
|
||||
Kernel_128x16_1x1x1_Coop::grouped_mm(
|
||||
out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales,
|
||||
@@ -358,18 +372,23 @@ void mm_dispatch(
|
||||
b_group_size, expert_offsets, problem_sizes, a_strides, b_strides,
|
||||
c_strides, group_scale_strides);
|
||||
} else {
|
||||
TORCH_CHECK(false,
|
||||
"cutlass_w4a8_moe_mm: unknown schedule string: ", schedule);
|
||||
STD_TORCH_CHECK(false,
|
||||
"cutlass_w4a8_moe_mm: unknown schedule string: ", schedule);
|
||||
}
|
||||
}
|
||||
|
||||
void mm(torch::Tensor& out_tensors, const torch::Tensor& a_tensors,
|
||||
const torch::Tensor& b_tensors, const torch::Tensor& a_scales,
|
||||
const torch::Tensor& b_scales, const torch::Tensor& b_group_scales,
|
||||
const int64_t b_group_size, const torch::Tensor& expert_offsets,
|
||||
const torch::Tensor& problem_sizes, const torch::Tensor& a_strides,
|
||||
const torch::Tensor& b_strides, const torch::Tensor& c_strides,
|
||||
const torch::Tensor& group_scale_strides,
|
||||
void mm(torch::stable::Tensor& out_tensors,
|
||||
const torch::stable::Tensor& a_tensors,
|
||||
const torch::stable::Tensor& b_tensors,
|
||||
const torch::stable::Tensor& a_scales,
|
||||
const torch::stable::Tensor& b_scales,
|
||||
const torch::stable::Tensor& b_group_scales, const int64_t b_group_size,
|
||||
const torch::stable::Tensor& expert_offsets,
|
||||
const torch::stable::Tensor& problem_sizes,
|
||||
const torch::stable::Tensor& a_strides,
|
||||
const torch::stable::Tensor& b_strides,
|
||||
const torch::stable::Tensor& c_strides,
|
||||
const torch::stable::Tensor& group_scale_strides,
|
||||
std::optional<std::string> maybe_schedule) {
|
||||
// user has specified a schedule
|
||||
if (maybe_schedule) {
|
||||
@@ -406,26 +425,27 @@ void mm(torch::Tensor& out_tensors, const torch::Tensor& a_tensors,
|
||||
a_strides, b_strides, c_strides, group_scale_strides, schedule);
|
||||
}
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> encode_and_reorder_int4b(
|
||||
torch::Tensor const& b_tensors) {
|
||||
TORCH_CHECK(b_tensors.dtype() == torch::kInt32);
|
||||
TORCH_CHECK(b_tensors.dim() == 3); // (experts, n, k)
|
||||
TORCH_CHECK(b_tensors.is_contiguous());
|
||||
TORCH_CHECK(b_tensors.is_cuda());
|
||||
std::tuple<torch::stable::Tensor, torch::stable::Tensor>
|
||||
encode_and_reorder_int4b(torch::stable::Tensor const& b_tensors) {
|
||||
STD_TORCH_CHECK(b_tensors.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Int);
|
||||
STD_TORCH_CHECK(b_tensors.dim() == 3); // (experts, n, k)
|
||||
STD_TORCH_CHECK(b_tensors.is_contiguous());
|
||||
STD_TORCH_CHECK(b_tensors.is_cuda());
|
||||
|
||||
int n = static_cast<int>(b_tensors.size(1));
|
||||
int k = static_cast<int>(b_tensors.size(2)) * PackFactor; // logical k
|
||||
|
||||
// CUTLASS reorder_tensor requires k % 256 == 0 and n % 16 == 0.
|
||||
// These misalignments cause silent OOB unless run under Compute Sanitizer.
|
||||
TORCH_CHECK(k % 256 == 0, "logical k must be divisible by 256");
|
||||
TORCH_CHECK(n % 16 == 0, "n must be divisible by 16");
|
||||
STD_TORCH_CHECK(k % 256 == 0, "logical k must be divisible by 256");
|
||||
STD_TORCH_CHECK(n % 16 == 0, "n must be divisible by 16");
|
||||
|
||||
// we will store the layout to an int32 tensor;
|
||||
// this is the number of elements we need per layout
|
||||
constexpr size_t layout_width = sizeof(LayoutB_Reordered) / sizeof(int32_t);
|
||||
|
||||
torch::Tensor b_tensors_packed = torch::empty_like(b_tensors);
|
||||
torch::stable::Tensor b_tensors_packed = torch::stable::empty_like(b_tensors);
|
||||
int num_experts = static_cast<int>(b_tensors.size(0));
|
||||
|
||||
auto b_ptr = static_cast<QuantType const*>(b_tensors.const_data_ptr());
|
||||
@@ -435,7 +455,7 @@ std::tuple<torch::Tensor, torch::Tensor> encode_and_reorder_int4b(
|
||||
size_t num_int4_elems = 1ull * num_experts * n * k;
|
||||
bool ok = vllm::cutlass_w4a8_utils::unified_encode_int4b(b_ptr, b_packed_ptr,
|
||||
num_int4_elems);
|
||||
TORCH_CHECK(ok, "unified_encode_int4b failed");
|
||||
STD_TORCH_CHECK(ok, "unified_encode_int4b failed");
|
||||
|
||||
// construct the layout once; assumes each expert has the same layout
|
||||
using LayoutType = LayoutB_Reordered;
|
||||
@@ -456,28 +476,28 @@ std::tuple<torch::Tensor, torch::Tensor> encode_and_reorder_int4b(
|
||||
}
|
||||
|
||||
// save the packed layout to torch tensor so we can re-use it
|
||||
auto cpu_opts =
|
||||
torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU);
|
||||
torch::Tensor layout_cpu =
|
||||
torch::empty({num_experts, layout_width}, cpu_opts);
|
||||
torch::stable::Tensor layout_cpu = torch::stable::empty(
|
||||
{num_experts, layout_width}, torch::headeronly::ScalarType::Int,
|
||||
std::nullopt, torch::stable::Device(torch::stable::DeviceType::CPU));
|
||||
|
||||
int32_t* layout_data = layout_cpu.data_ptr<int32_t>();
|
||||
int32_t* layout_data = layout_cpu.mutable_data_ptr<int32_t>();
|
||||
for (int i = 0; i < num_experts; ++i) {
|
||||
std::memcpy(layout_data + i * layout_width, // dst (int32*)
|
||||
&layout_B_reordered, // src (LayoutType*)
|
||||
sizeof(LayoutType)); // number of bytes
|
||||
}
|
||||
|
||||
torch::Tensor packed_layout =
|
||||
layout_cpu.to(b_tensors.device(), /*non_blocking=*/false);
|
||||
torch::stable::Tensor packed_layout =
|
||||
torch::stable::to(layout_cpu, b_tensors.device(),
|
||||
/*non_blocking=*/false);
|
||||
|
||||
return {b_tensors_packed, packed_layout};
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("cutlass_w4a8_moe_mm", &mm);
|
||||
m.impl("cutlass_encode_and_reorder_int4b_grouped", &encode_and_reorder_int4b);
|
||||
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) {
|
||||
m.impl("cutlass_w4a8_moe_mm", TORCH_BOX(&mm));
|
||||
m.impl("cutlass_encode_and_reorder_int4b_grouped",
|
||||
TORCH_BOX(&encode_and_reorder_int4b));
|
||||
}
|
||||
|
||||
} // namespace vllm::cutlass_w4a8_moe
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -3,14 +3,12 @@
|
||||
// https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu
|
||||
//
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/all.h>
|
||||
#include <torch/csrc/stable/library.h>
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
#include "cutlass_extensions/torch_utils.hpp"
|
||||
#include "w4a8_utils.cuh"
|
||||
|
||||
#include "core/registration.h"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include <limits>
|
||||
|
||||
@@ -161,31 +159,31 @@ struct W4A8GemmKernel {
|
||||
using StrideD = typename GemmKernelShuffled::StrideD;
|
||||
using StrideS = typename CollectiveMainloopShuffled::StrideScale;
|
||||
|
||||
static torch::Tensor mm(torch::Tensor const& A,
|
||||
torch::Tensor const& B, // already packed
|
||||
torch::Tensor const& group_scales, // already packed
|
||||
int64_t group_size,
|
||||
torch::Tensor const& channel_scales,
|
||||
torch::Tensor const& token_scales,
|
||||
std::optional<at::ScalarType> const& maybe_out_type) {
|
||||
static torch::stable::Tensor mm(
|
||||
torch::stable::Tensor const& A,
|
||||
torch::stable::Tensor const& B, // already packed
|
||||
torch::stable::Tensor const& group_scales, // already packed
|
||||
int64_t group_size, torch::stable::Tensor const& channel_scales,
|
||||
torch::stable::Tensor const& token_scales,
|
||||
std::optional<torch::headeronly::ScalarType> const& maybe_out_type) {
|
||||
// TODO: param validation
|
||||
int m = A.size(0);
|
||||
int k = A.size(1);
|
||||
int n = B.size(1);
|
||||
|
||||
// safely cast group_size to int
|
||||
TORCH_CHECK(group_size > 0 && group_size <= std::numeric_limits<int>::max(),
|
||||
"group_size out of supported range for int: ", group_size);
|
||||
STD_TORCH_CHECK(
|
||||
group_size > 0 && group_size <= std::numeric_limits<int>::max(),
|
||||
"group_size out of supported range for int: ", group_size);
|
||||
int const group_size_int = static_cast<int>(group_size);
|
||||
|
||||
// Allocate output
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
|
||||
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||
A.get_device_index());
|
||||
auto device = A.device();
|
||||
auto stream = at::cuda::getCurrentCUDAStream(device.index());
|
||||
torch::Tensor D =
|
||||
torch::empty({m, n}, torch::TensorOptions()
|
||||
.dtype(equivalent_scalar_type_v<ElementD>)
|
||||
.device(device));
|
||||
auto stream = get_current_cuda_stream(device.index());
|
||||
torch::stable::Tensor D = torch::stable::empty(
|
||||
{m, n}, equivalent_scalar_type_v<ElementD>, std::nullopt, device);
|
||||
// prepare arg pointers
|
||||
auto A_ptr = static_cast<MmaType const*>(A.const_data_ptr());
|
||||
auto B_ptr = static_cast<QuantType const*>(B.const_data_ptr());
|
||||
@@ -237,9 +235,9 @@ struct W4A8GemmKernel {
|
||||
|
||||
// Workspace
|
||||
size_t workspace_size = GemmShuffled::get_workspace_size(arguments);
|
||||
torch::Tensor workspace =
|
||||
torch::empty(workspace_size,
|
||||
torch::TensorOptions().dtype(torch::kU8).device(device));
|
||||
torch::stable::Tensor workspace = torch::stable::empty(
|
||||
workspace_size, torch::headeronly::ScalarType::Byte, std::nullopt,
|
||||
device);
|
||||
|
||||
// Run GEMM
|
||||
GemmShuffled gemm;
|
||||
@@ -269,14 +267,14 @@ using Kernel_128x64_1x1x1 = W4A8GemmKernel<Shape<_128, _64>, Shape<_1, _1, _1>>;
|
||||
using Kernel_128x32_1x1x1 = W4A8GemmKernel<Shape<_128, _32>, Shape<_1, _1, _1>>;
|
||||
using Kernel_128x16_1x1x1 = W4A8GemmKernel<Shape<_128, _16>, Shape<_1, _1, _1>>;
|
||||
|
||||
torch::Tensor mm_dispatch(torch::Tensor const& A,
|
||||
torch::Tensor const& B, // already packed
|
||||
torch::Tensor const& group_scales, // already packed
|
||||
int64_t group_size,
|
||||
torch::Tensor const& channel_scales,
|
||||
torch::Tensor const& token_scales,
|
||||
std::optional<at::ScalarType> const& maybe_out_type,
|
||||
const std::string& schedule) {
|
||||
torch::stable::Tensor mm_dispatch(
|
||||
torch::stable::Tensor const& A,
|
||||
torch::stable::Tensor const& B, // already packed
|
||||
torch::stable::Tensor const& group_scales, // already packed
|
||||
int64_t group_size, torch::stable::Tensor const& channel_scales,
|
||||
torch::stable::Tensor const& token_scales,
|
||||
std::optional<torch::headeronly::ScalarType> const& maybe_out_type,
|
||||
const std::string& schedule) {
|
||||
if (schedule == "256x128_1x1x1") {
|
||||
return Kernel_256x128_1x1x1::mm(A, B, group_scales, group_size,
|
||||
channel_scales, token_scales,
|
||||
@@ -318,17 +316,18 @@ torch::Tensor mm_dispatch(torch::Tensor const& A,
|
||||
channel_scales, token_scales,
|
||||
maybe_out_type);
|
||||
}
|
||||
TORCH_CHECK(false, "Unknown W4A8 schedule: ", schedule);
|
||||
STD_TORCH_CHECK(false, "Unknown W4A8 schedule: ", schedule);
|
||||
return {};
|
||||
}
|
||||
|
||||
torch::Tensor mm(torch::Tensor const& A,
|
||||
torch::Tensor const& B, // already packed
|
||||
torch::Tensor const& group_scales, // already packed
|
||||
int64_t group_size, torch::Tensor const& channel_scales,
|
||||
torch::Tensor const& token_scales,
|
||||
std::optional<at::ScalarType> const& maybe_out_type,
|
||||
std::optional<std::string> maybe_schedule) {
|
||||
torch::stable::Tensor mm(
|
||||
torch::stable::Tensor const& A,
|
||||
torch::stable::Tensor const& B, // already packed
|
||||
torch::stable::Tensor const& group_scales, // already packed
|
||||
int64_t group_size, torch::stable::Tensor const& channel_scales,
|
||||
torch::stable::Tensor const& token_scales,
|
||||
std::optional<torch::headeronly::ScalarType> const& maybe_out_type,
|
||||
std::optional<std::string> maybe_schedule) {
|
||||
// requested a specific schedule
|
||||
if (maybe_schedule) {
|
||||
return mm_dispatch(A, B, group_scales, group_size, channel_scales,
|
||||
@@ -378,14 +377,15 @@ torch::Tensor mm(torch::Tensor const& A,
|
||||
// ----------------------------------------------------------------------------
|
||||
// Pre-processing utils
|
||||
// ----------------------------------------------------------------------------
|
||||
torch::Tensor pack_scale_fp8(torch::Tensor const& scales) {
|
||||
TORCH_CHECK(scales.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(scales.is_contiguous());
|
||||
TORCH_CHECK(scales.is_cuda());
|
||||
torch::stable::Tensor pack_scale_fp8(torch::stable::Tensor const& scales) {
|
||||
STD_TORCH_CHECK(scales.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
STD_TORCH_CHECK(scales.is_contiguous());
|
||||
STD_TORCH_CHECK(scales.is_cuda());
|
||||
|
||||
auto packed_scales = torch::empty(
|
||||
{scales.numel() * ScalePackSize},
|
||||
torch::TensorOptions().dtype(scales.dtype()).device(scales.device()));
|
||||
auto packed_scales =
|
||||
torch::stable::empty({scales.numel() * ScalePackSize},
|
||||
scales.scalar_type(), std::nullopt, scales.device());
|
||||
auto scales_ptr = static_cast<MmaType const*>(scales.const_data_ptr());
|
||||
auto packed_scales_ptr =
|
||||
static_cast<cutlass::Array<ElementScale, ScalePackSize>*>(
|
||||
@@ -396,15 +396,16 @@ torch::Tensor pack_scale_fp8(torch::Tensor const& scales) {
|
||||
return packed_scales;
|
||||
}
|
||||
|
||||
torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) {
|
||||
TORCH_CHECK(B.dtype() == torch::kInt32);
|
||||
TORCH_CHECK(B.dim() == 2);
|
||||
torch::stable::Tensor encode_and_reorder_int4b(torch::stable::Tensor const& B) {
|
||||
STD_TORCH_CHECK(B.scalar_type() == torch::headeronly::ScalarType::Int);
|
||||
STD_TORCH_CHECK(B.dim() == 2);
|
||||
|
||||
torch::Tensor B_packed = torch::empty_like(B);
|
||||
torch::stable::Tensor B_packed = torch::stable::empty_like(B);
|
||||
|
||||
int k = B.size(0) * PackFactor; // logical k
|
||||
int n = B.size(1);
|
||||
TORCH_CHECK((n * k) % 32 == 0, "need multiples of 32 int4s for 16B chunks");
|
||||
STD_TORCH_CHECK((n * k) % 32 == 0,
|
||||
"need multiples of 32 int4s for 16B chunks");
|
||||
|
||||
auto B_ptr = static_cast<QuantType const*>(B.const_data_ptr());
|
||||
auto B_packed_ptr = static_cast<QuantType*>(B_packed.data_ptr());
|
||||
@@ -415,16 +416,17 @@ torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) {
|
||||
|
||||
bool ok = vllm::cutlass_w4a8_utils::unified_encode_int4b(B_ptr, B_packed_ptr,
|
||||
n * k);
|
||||
TORCH_CHECK(ok, "unified_encode_int4b failed");
|
||||
STD_TORCH_CHECK(ok, "unified_encode_int4b failed");
|
||||
cutlass::reorder_tensor(B_packed_ptr, layout_B, layout_B_reordered);
|
||||
|
||||
return B_packed;
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("cutlass_w4a8_mm", &mm);
|
||||
m.impl("cutlass_pack_scale_fp8", &pack_scale_fp8);
|
||||
m.impl("cutlass_encode_and_reorder_int4b", &encode_and_reorder_int4b);
|
||||
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) {
|
||||
m.impl("cutlass_w4a8_mm", TORCH_BOX(&mm));
|
||||
m.impl("cutlass_pack_scale_fp8", TORCH_BOX(&pack_scale_fp8));
|
||||
m.impl("cutlass_encode_and_reorder_int4b",
|
||||
TORCH_BOX(&encode_and_reorder_int4b));
|
||||
}
|
||||
|
||||
} // namespace vllm::cutlass_w4a8
|
||||
} // namespace vllm::cutlass_w4a8
|
||||
@@ -14,16 +14,15 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <torch/all.h>
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
#include "libtorch_stable/dispatch_utils.h"
|
||||
#include "cuda_vec_utils.cuh"
|
||||
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include <cuda_fp8.h>
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
#include "cuda_utils.h"
|
||||
#include "launch_bounds_utils.h"
|
||||
@@ -118,17 +117,19 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d]
|
||||
torch::Tensor& output_sf,
|
||||
torch::Tensor& input, // [..., 2 * d]
|
||||
torch::Tensor& input_sf) {
|
||||
void silu_and_mul_nvfp4_quant_sm1xxa(
|
||||
torch::stable::Tensor& output, // [..., d]
|
||||
torch::stable::Tensor& output_sf,
|
||||
torch::stable::Tensor& input, // [..., 2 * d]
|
||||
torch::stable::Tensor& input_sf) {
|
||||
int32_t m = input.size(0);
|
||||
int32_t n = input.size(1) / 2;
|
||||
|
||||
TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");
|
||||
TORCH_CHECK(input.scalar_type() == at::ScalarType::Half ||
|
||||
input.scalar_type() == at::ScalarType::BFloat16,
|
||||
"Unsupported input data type for quantize_to_fp4.");
|
||||
STD_TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");
|
||||
STD_TORCH_CHECK(
|
||||
input.scalar_type() == torch::headeronly::ScalarType::Half ||
|
||||
input.scalar_type() == torch::headeronly::ScalarType::BFloat16,
|
||||
"Unsupported input data type for quantize_to_fp4.");
|
||||
|
||||
int multiProcessorCount =
|
||||
get_device_attribute(cudaDevAttrMultiProcessorCount, -1);
|
||||
@@ -136,8 +137,9 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d]
|
||||
auto input_sf_ptr = static_cast<float const*>(input_sf.data_ptr());
|
||||
auto sf_out = static_cast<int32_t*>(output_sf.data_ptr());
|
||||
auto output_ptr = static_cast<int64_t*>(output.data_ptr());
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
|
||||
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||
input.get_device_index());
|
||||
auto stream = get_current_cuda_stream(input.get_device_index());
|
||||
dim3 block(std::min(int(n / ELTS_PER_THREAD), 512));
|
||||
int const numBlocksPerSM =
|
||||
vllm_runtime_blocks_per_sm(static_cast<int>(block.x));
|
||||
@@ -149,7 +151,7 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d]
|
||||
int(m), std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y));
|
||||
dim3 grid(grid_x, grid_y);
|
||||
|
||||
VLLM_DISPATCH_HALF_TYPES(
|
||||
VLLM_STABLE_DISPATCH_HALF_TYPES(
|
||||
input.scalar_type(), "silu_and_mul_nvfp4_quant_kernel", [&] {
|
||||
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
|
||||
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
|
||||
@@ -14,14 +14,12 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "core/registration.h"
|
||||
#include <torch/csrc/stable/library.h>
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
|
||||
#include <torch/all.h>
|
||||
#include <cutlass/arch/arch.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
@@ -122,7 +120,7 @@ __global__ void __get_group_gemm_starts(
|
||||
#define __CALL_GET_STARTS_KERNEL_BLOCKSCALE(ELEMENT_AB_TYPE, SF_TYPE, \
|
||||
TENSOR_C_TYPE, C_TYPE, LayoutSFA, \
|
||||
LayoutSFB, ScaleConfig) \
|
||||
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
|
||||
else if (out_tensors.scalar_type() == TENSOR_C_TYPE) { \
|
||||
__get_group_gemm_starts<ELEMENT_AB_TYPE, C_TYPE, SF_TYPE, float, \
|
||||
LayoutSFA, LayoutSFB, ScaleConfig> \
|
||||
<<<1, num_experts, 0, stream>>>( \
|
||||
@@ -150,50 +148,64 @@ __global__ void __get_group_gemm_starts(
|
||||
}
|
||||
|
||||
template <typename LayoutSFA, typename LayoutSFB, typename ScaleConfig>
|
||||
void run_get_group_gemm_starts(
|
||||
const torch::Tensor& a_starts, const torch::Tensor& b_starts,
|
||||
const torch::Tensor& out_starts, const torch::Tensor& a_scales_starts,
|
||||
const torch::Tensor& b_scales_starts, const torch::Tensor& alpha_starts,
|
||||
const torch::Tensor& layout_sfa, const torch::Tensor& layout_sfb,
|
||||
const torch::Tensor& a_strides, const torch::Tensor& b_strides,
|
||||
const torch::Tensor& c_strides, int64_t a_stride_val, int64_t b_stride_val,
|
||||
int64_t c_stride_val,
|
||||
/*these are used for their base addresses*/
|
||||
torch::Tensor const& a_tensors, torch::Tensor const& b_tensors,
|
||||
torch::Tensor const& out_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& alphas,
|
||||
torch::Tensor const& expert_offsets, torch::Tensor const& sf_offsets,
|
||||
torch::Tensor const& problem_sizes, int M, int N, int K) {
|
||||
void run_get_group_gemm_starts(const torch::stable::Tensor& a_starts,
|
||||
const torch::stable::Tensor& b_starts,
|
||||
const torch::stable::Tensor& out_starts,
|
||||
const torch::stable::Tensor& a_scales_starts,
|
||||
const torch::stable::Tensor& b_scales_starts,
|
||||
const torch::stable::Tensor& alpha_starts,
|
||||
const torch::stable::Tensor& layout_sfa,
|
||||
const torch::stable::Tensor& layout_sfb,
|
||||
const torch::stable::Tensor& a_strides,
|
||||
const torch::stable::Tensor& b_strides,
|
||||
const torch::stable::Tensor& c_strides,
|
||||
int64_t a_stride_val, int64_t b_stride_val,
|
||||
int64_t c_stride_val,
|
||||
/*these are used for their base addresses*/
|
||||
torch::stable::Tensor const& a_tensors,
|
||||
torch::stable::Tensor const& b_tensors,
|
||||
torch::stable::Tensor const& out_tensors,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
torch::stable::Tensor const& alphas,
|
||||
torch::stable::Tensor const& expert_offsets,
|
||||
torch::stable::Tensor const& sf_offsets,
|
||||
torch::stable::Tensor const& problem_sizes,
|
||||
int M, int N, int K) {
|
||||
int num_experts = (int)expert_offsets.size(0);
|
||||
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
|
||||
auto stream = get_current_cuda_stream(a_tensors.get_device_index());
|
||||
|
||||
TORCH_CHECK(out_tensors.size(1) == N,
|
||||
"Output tensor shape doesn't match expected shape");
|
||||
TORCH_CHECK(K / 2 == b_tensors.size(2),
|
||||
"b_tensors(dim = 2) and a_tensors(dim = 1) trailing"
|
||||
" dimension must match");
|
||||
STD_TORCH_CHECK(out_tensors.size(1) == N,
|
||||
"Output tensor shape doesn't match expected shape");
|
||||
STD_TORCH_CHECK(K / 2 == b_tensors.size(2),
|
||||
"b_tensors(dim = 2) and a_tensors(dim = 1) trailing"
|
||||
" dimension must match");
|
||||
if (false) {
|
||||
}
|
||||
//(ELEMENT_AB_TYPE, BS_TYPE, TENSOR_C_TYPE, C_TYPE, LayoutSFA, LayoutSFB,
|
||||
// ScaleConfig)
|
||||
__CALL_GET_STARTS_KERNEL_BLOCKSCALE(
|
||||
cutlass::float_e2m1_t, cutlass::float_ue4m3_t, torch::kBFloat16,
|
||||
cutlass::bfloat16_t, LayoutSFA, LayoutSFB, ScaleConfig)
|
||||
cutlass::float_e2m1_t, cutlass::float_ue4m3_t,
|
||||
torch::headeronly::ScalarType::BFloat16, cutlass::bfloat16_t, LayoutSFA,
|
||||
LayoutSFB, ScaleConfig)
|
||||
__CALL_GET_STARTS_KERNEL_BLOCKSCALE(cutlass::float_e2m1_t,
|
||||
cutlass::float_ue4m3_t, torch::kFloat16,
|
||||
half, LayoutSFA, LayoutSFB, ScaleConfig)
|
||||
cutlass::float_ue4m3_t,
|
||||
torch::headeronly::ScalarType::Half, half,
|
||||
LayoutSFA, LayoutSFB, ScaleConfig)
|
||||
else {
|
||||
TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
|
||||
STD_TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename OutType>
|
||||
void run_fp4_blockwise_scaled_group_mm_sm100(
|
||||
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
|
||||
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
|
||||
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
|
||||
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M,
|
||||
int N, int K) {
|
||||
torch::stable::Tensor& output, const torch::stable::Tensor& a,
|
||||
const torch::stable::Tensor& b, const torch::stable::Tensor& a_blockscale,
|
||||
const torch::stable::Tensor& b_blockscales,
|
||||
const torch::stable::Tensor& alphas,
|
||||
const torch::stable::Tensor& problem_sizes,
|
||||
const torch::stable::Tensor& expert_offsets,
|
||||
const torch::stable::Tensor& sf_offsets, int M, int N, int K) {
|
||||
using ProblemShape =
|
||||
cutlass::gemm::GroupProblemShape<Shape<int32_t, int32_t, int32_t>>;
|
||||
using ElementType = cutlass::float_e2m1_t;
|
||||
@@ -272,20 +284,40 @@ void run_fp4_blockwise_scaled_group_mm_sm100(
|
||||
|
||||
using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
|
||||
int num_experts = static_cast<int>(expert_offsets.size(0));
|
||||
auto options_int =
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a.device());
|
||||
|
||||
torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int);
|
||||
torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int);
|
||||
torch::Tensor a_strides1 = torch::empty(num_experts, options_int);
|
||||
torch::Tensor b_strides1 = torch::empty(num_experts, options_int);
|
||||
torch::Tensor c_strides1 = torch::empty(num_experts, options_int);
|
||||
torch::stable::Tensor a_ptrs =
|
||||
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
|
||||
std::nullopt, a.device());
|
||||
torch::stable::Tensor b_ptrs =
|
||||
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
|
||||
std::nullopt, a.device());
|
||||
torch::stable::Tensor out_ptrs =
|
||||
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
|
||||
std::nullopt, a.device());
|
||||
torch::stable::Tensor a_scales_ptrs =
|
||||
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
|
||||
std::nullopt, a.device());
|
||||
torch::stable::Tensor b_scales_ptrs =
|
||||
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
|
||||
std::nullopt, a.device());
|
||||
torch::stable::Tensor alpha_ptrs =
|
||||
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
|
||||
std::nullopt, a.device());
|
||||
torch::stable::Tensor layout_sfa = torch::stable::empty(
|
||||
{num_experts, 5}, torch::headeronly::ScalarType::Long, std::nullopt,
|
||||
a.device());
|
||||
torch::stable::Tensor layout_sfb = torch::stable::empty(
|
||||
{num_experts, 5}, torch::headeronly::ScalarType::Long, std::nullopt,
|
||||
a.device());
|
||||
torch::stable::Tensor a_strides1 =
|
||||
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
|
||||
std::nullopt, a.device());
|
||||
torch::stable::Tensor b_strides1 =
|
||||
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
|
||||
std::nullopt, a.device());
|
||||
torch::stable::Tensor c_strides1 =
|
||||
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
|
||||
std::nullopt, a.device());
|
||||
|
||||
run_get_group_gemm_starts<LayoutSFA, LayoutSFB, ScaleConfig>(
|
||||
a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, alpha_ptrs,
|
||||
@@ -308,7 +340,7 @@ void run_fp4_blockwise_scaled_group_mm_sm100(
|
||||
typename ProblemShape::UnderlyingProblemShape>::RasterOrderOptions;
|
||||
typename Gemm::GemmKernel::TileSchedulerArguments scheduler;
|
||||
scheduler.raster_order = RasterOrderOptions::AlongM;
|
||||
hw_info.device_id = a.get_device();
|
||||
hw_info.device_id = a.get_device_index();
|
||||
static std::unordered_map<int, int> cached_sm_counts;
|
||||
if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) {
|
||||
cached_sm_counts[hw_info.device_id] =
|
||||
@@ -350,32 +382,35 @@ void run_fp4_blockwise_scaled_group_mm_sm100(
|
||||
scheduler};
|
||||
|
||||
size_t workspace_size = Gemm::get_workspace_size(args);
|
||||
auto const workspace_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device());
|
||||
auto workspace =
|
||||
torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
|
||||
std::nullopt, a.device());
|
||||
const cudaStream_t stream = get_current_cuda_stream(a.get_device_index());
|
||||
|
||||
auto can_implement_status = gemm_op.can_implement(args);
|
||||
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess,
|
||||
"Failed to implement GEMM: status=", (int)can_implement_status);
|
||||
STD_TORCH_CHECK(
|
||||
can_implement_status == cutlass::Status::kSuccess,
|
||||
"Failed to implement GEMM: status=", (int)can_implement_status);
|
||||
|
||||
// Run the GEMM
|
||||
auto status = gemm_op.initialize(args, workspace.data_ptr());
|
||||
TORCH_CHECK(status == cutlass::Status::kSuccess,
|
||||
"Failed to initialize GEMM: status=", (int)status,
|
||||
" workspace_size=", workspace_size, " num_experts=", num_experts,
|
||||
" M=", M, " N=", N, " K=", K);
|
||||
STD_TORCH_CHECK(status == cutlass::Status::kSuccess,
|
||||
"Failed to initialize GEMM: status=", (int)status,
|
||||
" workspace_size=", workspace_size,
|
||||
" num_experts=", num_experts, " M=", M, " N=", N, " K=", K);
|
||||
|
||||
status = gemm_op.run(args, workspace.data_ptr(), stream);
|
||||
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
|
||||
STD_TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
|
||||
}
|
||||
|
||||
void run_fp4_blockwise_scaled_group_mm_sm120(
|
||||
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
|
||||
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
|
||||
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
|
||||
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M,
|
||||
int N, int K) {
|
||||
torch::stable::Tensor& output, const torch::stable::Tensor& a,
|
||||
const torch::stable::Tensor& b, const torch::stable::Tensor& a_blockscale,
|
||||
const torch::stable::Tensor& b_blockscales,
|
||||
const torch::stable::Tensor& alphas,
|
||||
const torch::stable::Tensor& problem_sizes,
|
||||
const torch::stable::Tensor& expert_offsets,
|
||||
const torch::stable::Tensor& sf_offsets, int M, int N, int K) {
|
||||
using ProblemShape =
|
||||
cutlass::gemm::GroupProblemShape<Shape<int32_t, int32_t, int32_t>>;
|
||||
using ElementType = cutlass::float_e2m1_t;
|
||||
@@ -446,20 +481,40 @@ void run_fp4_blockwise_scaled_group_mm_sm120(
|
||||
|
||||
using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
|
||||
int num_experts = static_cast<int>(expert_offsets.size(0));
|
||||
auto options_int =
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a.device());
|
||||
|
||||
torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int);
|
||||
torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int);
|
||||
torch::Tensor a_strides1 = torch::empty(num_experts, options_int);
|
||||
torch::Tensor b_strides1 = torch::empty(num_experts, options_int);
|
||||
torch::Tensor c_strides1 = torch::empty(num_experts, options_int);
|
||||
torch::stable::Tensor a_ptrs =
|
||||
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
|
||||
std::nullopt, a.device());
|
||||
torch::stable::Tensor b_ptrs =
|
||||
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
|
||||
std::nullopt, a.device());
|
||||
torch::stable::Tensor out_ptrs =
|
||||
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
|
||||
std::nullopt, a.device());
|
||||
torch::stable::Tensor a_scales_ptrs =
|
||||
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
|
||||
std::nullopt, a.device());
|
||||
torch::stable::Tensor b_scales_ptrs =
|
||||
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
|
||||
std::nullopt, a.device());
|
||||
torch::stable::Tensor alpha_ptrs =
|
||||
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
|
||||
std::nullopt, a.device());
|
||||
torch::stable::Tensor layout_sfa = torch::stable::empty(
|
||||
{num_experts, 5}, torch::headeronly::ScalarType::Long, std::nullopt,
|
||||
a.device());
|
||||
torch::stable::Tensor layout_sfb = torch::stable::empty(
|
||||
{num_experts, 5}, torch::headeronly::ScalarType::Long, std::nullopt,
|
||||
a.device());
|
||||
torch::stable::Tensor a_strides1 =
|
||||
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
|
||||
std::nullopt, a.device());
|
||||
torch::stable::Tensor b_strides1 =
|
||||
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
|
||||
std::nullopt, a.device());
|
||||
torch::stable::Tensor c_strides1 =
|
||||
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
|
||||
std::nullopt, a.device());
|
||||
|
||||
run_get_group_gemm_starts<LayoutSFA, LayoutSFB, ScaleConfig>(
|
||||
a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, alpha_ptrs,
|
||||
@@ -480,7 +535,7 @@ void run_fp4_blockwise_scaled_group_mm_sm120(
|
||||
using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions;
|
||||
typename Gemm::GemmKernel::TileSchedulerArguments scheduler;
|
||||
scheduler.raster_order = RasterOrderOptions::AlongM;
|
||||
hw_info.device_id = a.get_device();
|
||||
hw_info.device_id = a.get_device_index();
|
||||
static std::unordered_map<int, int> cached_sm_counts;
|
||||
if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) {
|
||||
cached_sm_counts[hw_info.device_id] =
|
||||
@@ -523,33 +578,36 @@ void run_fp4_blockwise_scaled_group_mm_sm120(
|
||||
scheduler};
|
||||
|
||||
size_t workspace_size = Gemm::get_workspace_size(args);
|
||||
auto const workspace_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device());
|
||||
auto workspace =
|
||||
torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
|
||||
std::nullopt, a.device());
|
||||
const cudaStream_t stream = get_current_cuda_stream(a.get_device_index());
|
||||
|
||||
auto can_implement_status = gemm_op.can_implement(args);
|
||||
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess,
|
||||
"Failed to implement GEMM: status=", (int)can_implement_status);
|
||||
STD_TORCH_CHECK(
|
||||
can_implement_status == cutlass::Status::kSuccess,
|
||||
"Failed to implement GEMM: status=", (int)can_implement_status);
|
||||
|
||||
// Run the GEMM
|
||||
auto status = gemm_op.initialize(args, workspace.data_ptr());
|
||||
TORCH_CHECK(status == cutlass::Status::kSuccess,
|
||||
"Failed to initialize GEMM: status=", (int)status,
|
||||
" workspace_size=", workspace_size, " num_experts=", num_experts,
|
||||
" M=", M, " N=", N, " K=", K);
|
||||
STD_TORCH_CHECK(status == cutlass::Status::kSuccess,
|
||||
"Failed to initialize GEMM: status=", (int)status,
|
||||
" workspace_size=", workspace_size,
|
||||
" num_experts=", num_experts, " M=", M, " N=", N, " K=", K);
|
||||
|
||||
status = gemm_op.run(args, workspace.data_ptr(), stream);
|
||||
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
|
||||
STD_TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
|
||||
}
|
||||
|
||||
template <typename OutType>
|
||||
void run_fp4_blockwise_scaled_group_mm(
|
||||
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
|
||||
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
|
||||
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
|
||||
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M,
|
||||
int N, int K) {
|
||||
torch::stable::Tensor& output, const torch::stable::Tensor& a,
|
||||
const torch::stable::Tensor& b, const torch::stable::Tensor& a_blockscale,
|
||||
const torch::stable::Tensor& b_blockscales,
|
||||
const torch::stable::Tensor& alphas,
|
||||
const torch::stable::Tensor& problem_sizes,
|
||||
const torch::stable::Tensor& expert_offsets,
|
||||
const torch::stable::Tensor& sf_offsets, int M, int N, int K) {
|
||||
int32_t version_num = get_sm_version_num();
|
||||
#if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
|
||||
if (version_num >= 120 && version_num < 130) {
|
||||
@@ -567,7 +625,7 @@ void run_fp4_blockwise_scaled_group_mm(
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
STD_TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false,
|
||||
"No compiled cutlass_fp4_group_mm kernel for CUDA device capability: ",
|
||||
version_num, ". Required capability: 100 or 120");
|
||||
@@ -575,26 +633,31 @@ void run_fp4_blockwise_scaled_group_mm(
|
||||
|
||||
#if (defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100) || \
|
||||
(defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120)
|
||||
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
|
||||
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
|
||||
constexpr auto FLOAT4_E2M1X2 = torch::headeronly::ScalarType::Byte;
|
||||
constexpr auto SF_DTYPE = torch::headeronly::ScalarType::Float8_e4m3fn;
|
||||
#endif
|
||||
|
||||
#define CHECK_TYPE(x, st, m) \
|
||||
TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m)
|
||||
#define CHECK_TYPE(x, st, m) \
|
||||
STD_TORCH_CHECK(x.scalar_type() == st, \
|
||||
": Inconsistency of torch::stable::Tensor type:", m)
|
||||
#define CHECK_TH_CUDA(x, m) \
|
||||
TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor.")
|
||||
STD_TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor.")
|
||||
#define CHECK_CONTIGUOUS(x, m) \
|
||||
TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous.")
|
||||
STD_TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous.")
|
||||
#define CHECK_INPUT(x, st, m) \
|
||||
CHECK_TH_CUDA(x, m); \
|
||||
CHECK_CONTIGUOUS(x, m); \
|
||||
CHECK_TYPE(x, st, m)
|
||||
|
||||
void cutlass_fp4_group_mm(
|
||||
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
|
||||
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
|
||||
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
|
||||
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets) {
|
||||
void cutlass_fp4_group_mm(torch::stable::Tensor& output,
|
||||
const torch::stable::Tensor& a,
|
||||
const torch::stable::Tensor& b,
|
||||
const torch::stable::Tensor& a_blockscale,
|
||||
const torch::stable::Tensor& b_blockscales,
|
||||
const torch::stable::Tensor& alphas,
|
||||
const torch::stable::Tensor& problem_sizes,
|
||||
const torch::stable::Tensor& expert_offsets,
|
||||
const torch::stable::Tensor& sf_offsets) {
|
||||
#if (defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100) || \
|
||||
(defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120)
|
||||
// Input validation
|
||||
@@ -602,30 +665,34 @@ void cutlass_fp4_group_mm(
|
||||
CHECK_INPUT(b, FLOAT4_E2M1X2, "b");
|
||||
CHECK_INPUT(a_blockscale, SF_DTYPE, "a_blockscale");
|
||||
CHECK_INPUT(b_blockscales, SF_DTYPE, "b_blockscales");
|
||||
CHECK_INPUT(alphas, at::ScalarType::Float, "alphas");
|
||||
CHECK_INPUT(alphas, torch::headeronly::ScalarType::Float, "alphas");
|
||||
|
||||
TORCH_CHECK(a_blockscale.dim() == 2,
|
||||
"expected a_blockscale to be of shape [num_experts, rounded_m,"
|
||||
" k // group_size], observed rank: ",
|
||||
a_blockscale.dim())
|
||||
TORCH_CHECK(b_blockscales.dim() == 3,
|
||||
"expected b_blockscale to be of shape: "
|
||||
" [num_experts, n, k // group_size], observed rank: ",
|
||||
b_blockscales.dim())
|
||||
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be a 2D tensor");
|
||||
TORCH_CHECK(problem_sizes.size(1) == 3,
|
||||
"problem_sizes must have the shape (num_experts, 3)");
|
||||
TORCH_CHECK(problem_sizes.size(0) == expert_offsets.size(0),
|
||||
"Number of experts in problem_sizes must match expert_offsets");
|
||||
TORCH_CHECK(problem_sizes.dtype() == torch::kInt32,
|
||||
"problem_sizes must be int32.");
|
||||
STD_TORCH_CHECK(
|
||||
a_blockscale.dim() == 2,
|
||||
"expected a_blockscale to be of shape [num_experts, rounded_m,"
|
||||
" k // group_size], observed rank: ",
|
||||
a_blockscale.dim())
|
||||
STD_TORCH_CHECK(b_blockscales.dim() == 3,
|
||||
"expected b_blockscale to be of shape: "
|
||||
" [num_experts, n, k // group_size], observed rank: ",
|
||||
b_blockscales.dim())
|
||||
STD_TORCH_CHECK(problem_sizes.dim() == 2,
|
||||
"problem_sizes must be a 2D tensor");
|
||||
STD_TORCH_CHECK(problem_sizes.size(1) == 3,
|
||||
"problem_sizes must have the shape (num_experts, 3)");
|
||||
STD_TORCH_CHECK(
|
||||
problem_sizes.size(0) == expert_offsets.size(0),
|
||||
"Number of experts in problem_sizes must match expert_offsets");
|
||||
STD_TORCH_CHECK(
|
||||
problem_sizes.scalar_type() == torch::headeronly::ScalarType::Int,
|
||||
"problem_sizes must be int32.");
|
||||
|
||||
int M = static_cast<int>(a.size(0));
|
||||
int N = static_cast<int>(b.size(1));
|
||||
int E = static_cast<int>(b.size(0));
|
||||
int K = static_cast<int>(2 * b.size(2));
|
||||
|
||||
if (output.scalar_type() == torch::kBFloat16) {
|
||||
if (output.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
|
||||
run_fp4_blockwise_scaled_group_mm<cutlass::bfloat16_t>(
|
||||
output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes,
|
||||
expert_offsets, sf_offsets, M, N, K);
|
||||
@@ -633,7 +700,7 @@ void cutlass_fp4_group_mm(
|
||||
#if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
|
||||
int32_t version_num = get_sm_version_num();
|
||||
if (version_num >= 120 && version_num < 130) {
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
STD_TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false, "SM120 NVFP4 MOE only supports bfloat16 output, got: ",
|
||||
output.scalar_type());
|
||||
}
|
||||
@@ -643,7 +710,7 @@ void cutlass_fp4_group_mm(
|
||||
expert_offsets, sf_offsets, M, N, K);
|
||||
}
|
||||
#else
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
STD_TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false,
|
||||
"No compiled cutlass_fp4_group_mm kernel, vLLM must "
|
||||
"be compiled with ENABLE_NVFP4_SM100 or ENABLE_NVFP4_SM120 for SM100/120 "
|
||||
@@ -651,6 +718,6 @@ void cutlass_fp4_group_mm(
|
||||
#endif
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("cutlass_fp4_group_mm", &cutlass_fp4_group_mm);
|
||||
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) {
|
||||
m.impl("cutlass_fp4_group_mm", TORCH_BOX(&cutlass_fp4_group_mm));
|
||||
}
|
||||
@@ -14,16 +14,15 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <torch/all.h>
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
#include "libtorch_stable/dispatch_utils.h"
|
||||
#include "cuda_vec_utils.cuh"
|
||||
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include <cuda_fp8.h>
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
#include "cuda_utils.h"
|
||||
#include "nvfp4_utils.cuh"
|
||||
@@ -327,25 +326,28 @@ void quant_impl(void* output, void* output_scale, void* input,
|
||||
} // namespace vllm
|
||||
|
||||
/*Quantization entry for fp4 experts quantization*/
|
||||
#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
|
||||
#define CHECK_TH_CUDA(x, m) \
|
||||
STD_TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x, m) \
|
||||
TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
|
||||
STD_TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
|
||||
#define CHECK_INPUT(x, m) \
|
||||
CHECK_TH_CUDA(x, m); \
|
||||
CHECK_CONTIGUOUS(x, m);
|
||||
|
||||
constexpr auto HALF = at::ScalarType::Half;
|
||||
constexpr auto BF16 = at::ScalarType::BFloat16;
|
||||
constexpr auto FLOAT = at::ScalarType::Float;
|
||||
constexpr auto INT = at::ScalarType::Int;
|
||||
constexpr auto UINT8 = at::ScalarType::Byte;
|
||||
constexpr auto HALF = torch::headeronly::ScalarType::Half;
|
||||
constexpr auto BF16 = torch::headeronly::ScalarType::BFloat16;
|
||||
constexpr auto FLOAT = torch::headeronly::ScalarType::Float;
|
||||
constexpr auto INT = torch::headeronly::ScalarType::Int;
|
||||
constexpr auto UINT8 = torch::headeronly::ScalarType::Byte;
|
||||
|
||||
// Common validation for fp4 experts quantization entry points.
|
||||
static void validate_fp4_experts_quant_inputs(
|
||||
torch::Tensor const& output, torch::Tensor const& output_scale,
|
||||
torch::Tensor const& input, torch::Tensor const& input_global_scale,
|
||||
torch::Tensor const& input_offset_by_experts,
|
||||
torch::Tensor const& output_scale_offset_by_experts, int64_t m_topk,
|
||||
torch::stable::Tensor const& output,
|
||||
torch::stable::Tensor const& output_scale,
|
||||
torch::stable::Tensor const& input,
|
||||
torch::stable::Tensor const& input_global_scale,
|
||||
torch::stable::Tensor const& input_offset_by_experts,
|
||||
torch::stable::Tensor const& output_scale_offset_by_experts, int64_t m_topk,
|
||||
int64_t k) {
|
||||
CHECK_INPUT(output, "output");
|
||||
CHECK_INPUT(output_scale, "output_scale");
|
||||
@@ -354,41 +356,42 @@ static void validate_fp4_experts_quant_inputs(
|
||||
CHECK_INPUT(input_offset_by_experts, "input_offset_by_experts");
|
||||
CHECK_INPUT(output_scale_offset_by_experts, "output_scale_offset_by_experts");
|
||||
|
||||
TORCH_CHECK(output.dim() == 2);
|
||||
TORCH_CHECK(output_scale.dim() == 2);
|
||||
TORCH_CHECK(input.dim() == 2);
|
||||
TORCH_CHECK(input_global_scale.dim() == 1);
|
||||
TORCH_CHECK(input_offset_by_experts.dim() == 1);
|
||||
TORCH_CHECK(output_scale_offset_by_experts.dim() == 1);
|
||||
STD_TORCH_CHECK(output.dim() == 2);
|
||||
STD_TORCH_CHECK(output_scale.dim() == 2);
|
||||
STD_TORCH_CHECK(input.dim() == 2);
|
||||
STD_TORCH_CHECK(input_global_scale.dim() == 1);
|
||||
STD_TORCH_CHECK(input_offset_by_experts.dim() == 1);
|
||||
STD_TORCH_CHECK(output_scale_offset_by_experts.dim() == 1);
|
||||
|
||||
TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16);
|
||||
TORCH_CHECK(input_global_scale.scalar_type() == FLOAT);
|
||||
TORCH_CHECK(input_offset_by_experts.scalar_type() == INT);
|
||||
TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT);
|
||||
STD_TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16);
|
||||
STD_TORCH_CHECK(input_global_scale.scalar_type() == FLOAT);
|
||||
STD_TORCH_CHECK(input_offset_by_experts.scalar_type() == INT);
|
||||
STD_TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT);
|
||||
// output is uint8 (two nvfp4 values are packed into one uint8)
|
||||
// output_scale is int32 (four fp8 values are packed into one int32)
|
||||
TORCH_CHECK(output.scalar_type() == UINT8);
|
||||
TORCH_CHECK(output_scale.scalar_type() == INT);
|
||||
STD_TORCH_CHECK(output.scalar_type() == UINT8);
|
||||
STD_TORCH_CHECK(output_scale.scalar_type() == INT);
|
||||
|
||||
const int BLOCK_SIZE = 16;
|
||||
TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16");
|
||||
STD_TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16");
|
||||
auto n_experts = input_global_scale.size(0);
|
||||
TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1);
|
||||
TORCH_CHECK(output_scale_offset_by_experts.size(0) == n_experts + 1);
|
||||
TORCH_CHECK(output.size(0) == m_topk);
|
||||
TORCH_CHECK(output.size(1) == k / 2);
|
||||
STD_TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1);
|
||||
STD_TORCH_CHECK(output_scale_offset_by_experts.size(0) == n_experts + 1);
|
||||
STD_TORCH_CHECK(output.size(0) == m_topk);
|
||||
STD_TORCH_CHECK(output.size(1) == k / 2);
|
||||
int scales_k = k / BLOCK_SIZE;
|
||||
// 4 means the swizzle requirement by nvidia nvfp4.
|
||||
int padded_k = (scales_k + (4 - 1)) / 4 * 4;
|
||||
// 4 means 4 fp8 values are packed into one int32
|
||||
TORCH_CHECK(output_scale.size(1) * 4 == padded_k);
|
||||
STD_TORCH_CHECK(output_scale.size(1) * 4 == padded_k);
|
||||
}
|
||||
|
||||
void scaled_fp4_experts_quant_sm1xxa(
|
||||
torch::Tensor& output, torch::Tensor& output_scale,
|
||||
torch::Tensor const& input, torch::Tensor const& input_global_scale,
|
||||
torch::Tensor const& input_offset_by_experts,
|
||||
torch::Tensor const& output_scale_offset_by_experts) {
|
||||
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
|
||||
torch::stable::Tensor const& input,
|
||||
torch::stable::Tensor const& input_global_scale,
|
||||
torch::stable::Tensor const& input_offset_by_experts,
|
||||
torch::stable::Tensor const& output_scale_offset_by_experts) {
|
||||
auto m_topk = input.size(0);
|
||||
auto k = input.size(1);
|
||||
|
||||
@@ -397,11 +400,11 @@ void scaled_fp4_experts_quant_sm1xxa(
|
||||
output_scale_offset_by_experts, m_topk, k);
|
||||
|
||||
auto n_experts = input_global_scale.size(0);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream =
|
||||
at::cuda::getCurrentCUDAStream(input.get_device());
|
||||
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||
input.get_device_index());
|
||||
const cudaStream_t stream = get_current_cuda_stream(input.get_device_index());
|
||||
|
||||
VLLM_DISPATCH_HALF_TYPES(
|
||||
VLLM_STABLE_DISPATCH_HALF_TYPES(
|
||||
input.scalar_type(), "nvfp4_experts_quant_kernel", [&] {
|
||||
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
|
||||
vllm::quant_impl<cuda_type, /*FUSE_SILU_MUL=*/false>(
|
||||
@@ -413,14 +416,15 @@ void scaled_fp4_experts_quant_sm1xxa(
|
||||
}
|
||||
|
||||
void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
|
||||
torch::Tensor& output, torch::Tensor& output_scale,
|
||||
torch::Tensor const& input, torch::Tensor const& input_global_scale,
|
||||
torch::Tensor const& input_offset_by_experts,
|
||||
torch::Tensor const& output_scale_offset_by_experts) {
|
||||
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
|
||||
torch::stable::Tensor const& input,
|
||||
torch::stable::Tensor const& input_global_scale,
|
||||
torch::stable::Tensor const& input_offset_by_experts,
|
||||
torch::stable::Tensor const& output_scale_offset_by_experts) {
|
||||
auto m_topk = input.size(0);
|
||||
// Input has gate || up layout, so k = input.size(1) / 2
|
||||
auto k_times_2 = input.size(1);
|
||||
TORCH_CHECK(k_times_2 % 2 == 0, "input width must be even (gate || up)");
|
||||
STD_TORCH_CHECK(k_times_2 % 2 == 0, "input width must be even (gate || up)");
|
||||
auto k = k_times_2 / 2;
|
||||
|
||||
validate_fp4_experts_quant_inputs(output, output_scale, input,
|
||||
@@ -428,11 +432,11 @@ void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
|
||||
output_scale_offset_by_experts, m_topk, k);
|
||||
|
||||
auto n_experts = input_global_scale.size(0);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream =
|
||||
at::cuda::getCurrentCUDAStream(input.get_device());
|
||||
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||
input.get_device_index());
|
||||
const cudaStream_t stream = get_current_cuda_stream(input.get_device_index());
|
||||
|
||||
VLLM_DISPATCH_HALF_TYPES(
|
||||
VLLM_STABLE_DISPATCH_HALF_TYPES(
|
||||
input.scalar_type(), "silu_mul_nvfp4_experts_quant_kernel", [&] {
|
||||
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
|
||||
vllm::quant_impl<cuda_type, /*FUSE_SILU_MUL=*/true>(
|
||||
175
csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu
Normal file
175
csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu
Normal file
@@ -0,0 +1,175 @@
|
||||
/*
|
||||
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
#include "nvfp4_utils.cuh"
|
||||
|
||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||
void scaled_fp4_quant_sm1xxa(torch::stable::Tensor const& output,
|
||||
torch::stable::Tensor const& input,
|
||||
torch::stable::Tensor const& output_sf,
|
||||
torch::stable::Tensor const& input_sf,
|
||||
bool is_sf_swizzled_layout);
|
||||
#endif
|
||||
|
||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||
void scaled_fp4_experts_quant_sm1xxa(
|
||||
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
|
||||
torch::stable::Tensor const& input,
|
||||
torch::stable::Tensor const& input_global_scale,
|
||||
torch::stable::Tensor const& input_offset_by_experts,
|
||||
torch::stable::Tensor const& output_scale_offset_by_experts);
|
||||
#endif
|
||||
|
||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||
void silu_and_mul_nvfp4_quant_sm1xxa(torch::stable::Tensor& output,
|
||||
torch::stable::Tensor& output_sf,
|
||||
torch::stable::Tensor& input,
|
||||
torch::stable::Tensor& input_sf);
|
||||
#endif
|
||||
|
||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||
void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
|
||||
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
|
||||
torch::stable::Tensor const& input,
|
||||
torch::stable::Tensor const& input_global_scale,
|
||||
torch::stable::Tensor const& input_offset_by_experts,
|
||||
torch::stable::Tensor const& output_scale_offset_by_experts);
|
||||
#endif
|
||||
|
||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||
static bool nvfp4_quant_sm_supported() {
|
||||
const int32_t sm = get_sm_version_num();
|
||||
#if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100
|
||||
if (sm >= 100 && sm < 120) return true;
|
||||
#endif
|
||||
#if defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120
|
||||
if (sm >= 120 && sm < 130) return true;
|
||||
#endif
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
|
||||
void scaled_fp4_quant_out(torch::stable::Tensor const& input,
|
||||
torch::stable::Tensor const& input_sf,
|
||||
bool is_sf_swizzled_layout,
|
||||
torch::stable::Tensor& output,
|
||||
torch::stable::Tensor& output_sf) {
|
||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||
STD_TORCH_CHECK(nvfp4_quant_sm_supported(),
|
||||
"No compiled nvfp4 quantization kernel for SM ",
|
||||
get_sm_version_num(),
|
||||
". Recompile with the appropriate CUDA arch.");
|
||||
return scaled_fp4_quant_sm1xxa(output, input, output_sf, input_sf,
|
||||
is_sf_swizzled_layout);
|
||||
#endif
|
||||
STD_TORCH_CHECK_NOT_IMPLEMENTED(false,
|
||||
"No compiled nvfp4 quantization kernel");
|
||||
}
|
||||
|
||||
std::tuple<torch::stable::Tensor, torch::stable::Tensor> scaled_fp4_quant_func(
|
||||
torch::stable::Tensor const& input, torch::stable::Tensor const& input_sf,
|
||||
bool is_sf_swizzled_layout) {
|
||||
int64_t n = input.size(-1);
|
||||
int64_t m = input.numel() / n;
|
||||
auto device = input.device();
|
||||
|
||||
// Two fp4 values packed into a uint8
|
||||
auto output = torch::stable::empty(
|
||||
{m, n / 2}, torch::headeronly::ScalarType::Byte, std::nullopt, device);
|
||||
|
||||
torch::stable::Tensor output_sf;
|
||||
if (is_sf_swizzled_layout) {
|
||||
auto [sf_m, sf_n] = vllm::computeSwizzledSFShape(m, n);
|
||||
output_sf = torch::stable::empty(
|
||||
{sf_m, sf_n}, torch::headeronly::ScalarType::Int, std::nullopt, device);
|
||||
} else {
|
||||
output_sf = torch::stable::empty({m, n / CVT_FP4_SF_VEC_SIZE},
|
||||
torch::headeronly::ScalarType::Byte,
|
||||
std::nullopt, device);
|
||||
}
|
||||
|
||||
scaled_fp4_quant_out(input, input_sf, is_sf_swizzled_layout, output,
|
||||
output_sf);
|
||||
return {output, output_sf};
|
||||
}
|
||||
|
||||
void scaled_fp4_experts_quant(
|
||||
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
|
||||
torch::stable::Tensor const& input,
|
||||
torch::stable::Tensor const& input_global_scale,
|
||||
torch::stable::Tensor const& input_offset_by_experts,
|
||||
torch::stable::Tensor const& output_scale_offset_by_experts) {
|
||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||
STD_TORCH_CHECK(nvfp4_quant_sm_supported(),
|
||||
"No compiled nvfp4 experts quantization kernel for SM ",
|
||||
get_sm_version_num(),
|
||||
". Recompile with the appropriate CUDA arch.");
|
||||
return scaled_fp4_experts_quant_sm1xxa(
|
||||
output, output_scale, input, input_global_scale, input_offset_by_experts,
|
||||
output_scale_offset_by_experts);
|
||||
#endif
|
||||
STD_TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false, "No compiled nvfp4 experts quantization kernel");
|
||||
}
|
||||
|
||||
void silu_and_mul_nvfp4_quant(torch::stable::Tensor& output,
|
||||
torch::stable::Tensor& output_sf,
|
||||
torch::stable::Tensor& input,
|
||||
torch::stable::Tensor& input_sf) {
|
||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||
STD_TORCH_CHECK(nvfp4_quant_sm_supported(),
|
||||
"No compiled silu_and_mul nvfp4 quantization kernel for SM ",
|
||||
get_sm_version_num(),
|
||||
". Recompile with the appropriate CUDA arch.");
|
||||
return silu_and_mul_nvfp4_quant_sm1xxa(output, output_sf, input, input_sf);
|
||||
#endif
|
||||
STD_TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false, "No compiled silu_and_mul nvfp4 quantization kernel");
|
||||
}
|
||||
|
||||
void silu_and_mul_scaled_fp4_experts_quant(
|
||||
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
|
||||
torch::stable::Tensor const& input,
|
||||
torch::stable::Tensor const& input_global_scale,
|
||||
torch::stable::Tensor const& input_offset_by_experts,
|
||||
torch::stable::Tensor const& output_scale_offset_by_experts) {
|
||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||
STD_TORCH_CHECK(nvfp4_quant_sm_supported(),
|
||||
"No compiled silu_and_mul nvfp4 experts quantization kernel "
|
||||
"for SM ",
|
||||
get_sm_version_num(),
|
||||
". Recompile with the appropriate CUDA arch.");
|
||||
return silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
|
||||
output, output_scale, input, input_global_scale, input_offset_by_experts,
|
||||
output_scale_offset_by_experts);
|
||||
#endif
|
||||
STD_TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false, "No compiled silu_and_mul nvfp4 experts quantization kernel");
|
||||
}
|
||||
@@ -14,16 +14,16 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <torch/all.h>
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include <cuda_fp8.h>
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
#include "libtorch_stable/dispatch_utils.h"
|
||||
#include "cuda_vec_utils.cuh"
|
||||
|
||||
#include "cuda_utils.h"
|
||||
#include "launch_bounds_utils.h"
|
||||
@@ -173,18 +173,19 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
|
||||
torch::Tensor const& input,
|
||||
torch::Tensor const& output_sf,
|
||||
torch::Tensor const& input_sf,
|
||||
void scaled_fp4_quant_sm1xxa(torch::stable::Tensor const& output,
|
||||
torch::stable::Tensor const& input,
|
||||
torch::stable::Tensor const& output_sf,
|
||||
torch::stable::Tensor const& input_sf,
|
||||
bool is_sf_swizzled_layout) {
|
||||
int32_t m = input.size(0);
|
||||
int32_t n = input.size(1);
|
||||
|
||||
TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");
|
||||
TORCH_CHECK(input.scalar_type() == at::ScalarType::Half ||
|
||||
input.scalar_type() == at::ScalarType::BFloat16,
|
||||
"Unsupported input data type for quantize_to_fp4.");
|
||||
STD_TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");
|
||||
STD_TORCH_CHECK(
|
||||
input.scalar_type() == torch::headeronly::ScalarType::Half ||
|
||||
input.scalar_type() == torch::headeronly::ScalarType::BFloat16,
|
||||
"Unsupported input data type for quantize_to_fp4.");
|
||||
|
||||
int multiProcessorCount =
|
||||
get_device_attribute(cudaDevAttrMultiProcessorCount, -1);
|
||||
@@ -192,8 +193,9 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
|
||||
auto input_sf_ptr = static_cast<float const*>(input_sf.data_ptr());
|
||||
auto sf_out = static_cast<int32_t*>(output_sf.data_ptr());
|
||||
auto output_ptr = static_cast<int64_t*>(output.data_ptr());
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
|
||||
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||
input.get_device_index());
|
||||
auto stream = get_current_cuda_stream(input.get_device_index());
|
||||
|
||||
int sf_n_unpadded = int(n / CVT_FP4_SF_VEC_SIZE);
|
||||
|
||||
@@ -213,15 +215,15 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
|
||||
std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y));
|
||||
dim3 grid(grid_x, grid_y);
|
||||
|
||||
VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] {
|
||||
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
|
||||
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
|
||||
// NOTE: We don't support e8m0 scales at this moment.
|
||||
vllm::cvt_fp16_to_fp4<cuda_type, false><<<grid, block, 0, stream>>>(
|
||||
m, n, num_padded_cols, input_ptr, input_sf_ptr,
|
||||
reinterpret_cast<uint32_t*>(output_ptr),
|
||||
reinterpret_cast<uint32_t*>(sf_out));
|
||||
});
|
||||
VLLM_STABLE_DISPATCH_HALF_TYPES(
|
||||
input.scalar_type(), "nvfp4_quant_kernel", [&] {
|
||||
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
|
||||
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
|
||||
vllm::cvt_fp16_to_fp4<cuda_type, false><<<grid, block, 0, stream>>>(
|
||||
m, n, num_padded_cols, input_ptr, input_sf_ptr,
|
||||
reinterpret_cast<uint32_t*>(output_ptr),
|
||||
reinterpret_cast<uint32_t*>(sf_out));
|
||||
});
|
||||
} else {
|
||||
int num_packed_cols = n / CVT_FP4_ELTS_PER_THREAD;
|
||||
int grid_y = vllm::div_round_up(num_packed_cols, static_cast<int>(block.x));
|
||||
@@ -229,15 +231,15 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
|
||||
m, std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y));
|
||||
dim3 grid(grid_x, grid_y);
|
||||
|
||||
VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] {
|
||||
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
|
||||
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
|
||||
// NOTE: We don't support e8m0 scales at this moment.
|
||||
vllm::cvt_fp16_to_fp4_sf_major<cuda_type, false>
|
||||
<<<grid, block, 0, stream>>>(m, n, sf_n_unpadded, num_packed_cols,
|
||||
input_ptr, input_sf_ptr,
|
||||
reinterpret_cast<uint32_t*>(output_ptr),
|
||||
reinterpret_cast<uint32_t*>(sf_out));
|
||||
});
|
||||
VLLM_STABLE_DISPATCH_HALF_TYPES(
|
||||
input.scalar_type(), "nvfp4_quant_kernel", [&] {
|
||||
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
|
||||
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
|
||||
vllm::cvt_fp16_to_fp4_sf_major<cuda_type, false>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
m, n, sf_n_unpadded, num_packed_cols, input_ptr, input_sf_ptr,
|
||||
reinterpret_cast<uint32_t*>(output_ptr),
|
||||
reinterpret_cast<uint32_t*>(sf_out));
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -14,32 +14,39 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <torch/all.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
|
||||
#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
|
||||
void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
|
||||
torch::Tensor const& B,
|
||||
torch::Tensor const& A_sf,
|
||||
torch::Tensor const& B_sf,
|
||||
torch::Tensor const& alpha);
|
||||
void cutlass_scaled_fp4_mm_sm100a(torch::stable::Tensor& D,
|
||||
torch::stable::Tensor const& A,
|
||||
torch::stable::Tensor const& B,
|
||||
torch::stable::Tensor const& A_sf,
|
||||
torch::stable::Tensor const& B_sf,
|
||||
torch::stable::Tensor const& alpha);
|
||||
#endif
|
||||
|
||||
#if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
|
||||
void cutlass_scaled_fp4_mm_sm120a(torch::Tensor& D, torch::Tensor const& A,
|
||||
torch::Tensor const& B,
|
||||
torch::Tensor const& A_sf,
|
||||
torch::Tensor const& B_sf,
|
||||
torch::Tensor const& alpha);
|
||||
void cutlass_scaled_fp4_mm_sm120a(torch::stable::Tensor& D,
|
||||
torch::stable::Tensor const& A,
|
||||
torch::stable::Tensor const& B,
|
||||
torch::stable::Tensor const& A_sf,
|
||||
torch::stable::Tensor const& B_sf,
|
||||
torch::stable::Tensor const& alpha);
|
||||
#endif
|
||||
|
||||
void cutlass_scaled_fp4_mm(torch::Tensor& D, const torch::Tensor& A,
|
||||
const torch::Tensor& B, const torch::Tensor& A_sf,
|
||||
const torch::Tensor& B_sf,
|
||||
const torch::Tensor& alpha) {
|
||||
// Make sure we’re on A’s device.
|
||||
const c10::cuda::OptionalCUDAGuard device_guard(device_of(A));
|
||||
void cutlass_scaled_fp4_mm(torch::stable::Tensor& D,
|
||||
const torch::stable::Tensor& A,
|
||||
const torch::stable::Tensor& B,
|
||||
const torch::stable::Tensor& A_sf,
|
||||
const torch::stable::Tensor& B_sf,
|
||||
const torch::stable::Tensor& alpha) {
|
||||
// Make sure we're on A's device.
|
||||
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||
A.get_device_index());
|
||||
const int32_t sm = get_sm_version_num();
|
||||
|
||||
#if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100
|
||||
@@ -56,8 +63,9 @@ void cutlass_scaled_fp4_mm(torch::Tensor& D, const torch::Tensor& A,
|
||||
}
|
||||
#endif
|
||||
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 mm kernel for SM ", sm,
|
||||
". Recompile with CUDA >= 12.8 and CC >= 100.");
|
||||
STD_TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false, "No compiled nvfp4 mm kernel for SM ", sm,
|
||||
". Recompile with CUDA >= 12.8 and CC >= 100.");
|
||||
}
|
||||
|
||||
bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability) {
|
||||
@@ -14,10 +14,9 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <torch/all.h>
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
|
||||
@@ -127,8 +126,9 @@ struct Fp4GemmSm100 {
|
||||
|
||||
template <typename Config>
|
||||
typename Config::Gemm::Arguments args_from_options(
|
||||
at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
|
||||
at::Tensor const& A_sf, at::Tensor const& B_sf, at::Tensor const& alpha,
|
||||
torch::stable::Tensor& D, torch::stable::Tensor const& A,
|
||||
torch::stable::Tensor const& B, torch::stable::Tensor const& A_sf,
|
||||
torch::stable::Tensor const& B_sf, torch::stable::Tensor const& alpha,
|
||||
int64_t M, int64_t N, int64_t K) {
|
||||
using ElementA = typename Config::Gemm::ElementA;
|
||||
using ElementB = typename Config::Gemm::ElementB;
|
||||
@@ -174,19 +174,20 @@ typename Config::Gemm::Arguments args_from_options(
|
||||
}
|
||||
|
||||
template <typename Config>
|
||||
void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
|
||||
at::Tensor const& A_sf, at::Tensor const& B_sf,
|
||||
at::Tensor const& alpha, int64_t m, int64_t n, int64_t k,
|
||||
cudaStream_t stream) {
|
||||
void runGemm(torch::stable::Tensor& D, torch::stable::Tensor const& A,
|
||||
torch::stable::Tensor const& B, torch::stable::Tensor const& A_sf,
|
||||
torch::stable::Tensor const& B_sf,
|
||||
torch::stable::Tensor const& alpha, int64_t m, int64_t n,
|
||||
int64_t k, cudaStream_t stream) {
|
||||
typename Config::Gemm gemm;
|
||||
|
||||
auto arguments =
|
||||
args_from_options<Config>(D, A, B, A_sf, B_sf, alpha, m, n, k);
|
||||
|
||||
size_t workspace_size = Config::Gemm::get_workspace_size(arguments);
|
||||
auto const workspace_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
auto workspace =
|
||||
torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
|
||||
std::nullopt, A.device());
|
||||
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
@@ -197,12 +198,13 @@ void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
|
||||
|
||||
// Dispatch function to select appropriate config based on M
|
||||
template <typename OutType>
|
||||
void cutlass_fp4_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
|
||||
torch::Tensor const& B,
|
||||
torch::Tensor const& A_sf,
|
||||
torch::Tensor const& B_sf,
|
||||
torch::Tensor const& alpha, int64_t m, int64_t n,
|
||||
int64_t k, cudaStream_t stream) {
|
||||
void cutlass_fp4_gemm_dispatch(torch::stable::Tensor& D,
|
||||
torch::stable::Tensor const& A,
|
||||
torch::stable::Tensor const& B,
|
||||
torch::stable::Tensor const& A_sf,
|
||||
torch::stable::Tensor const& B_sf,
|
||||
torch::stable::Tensor const& alpha, int64_t m,
|
||||
int64_t n, int64_t k, cudaStream_t stream) {
|
||||
uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
|
||||
|
||||
if (mp2 <= 16) {
|
||||
@@ -222,61 +224,65 @@ void cutlass_fp4_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
|
||||
|
||||
#else
|
||||
template <typename OutType>
|
||||
void cutlass_fp4_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
|
||||
torch::Tensor const& B,
|
||||
torch::Tensor const& A_sf,
|
||||
torch::Tensor const& B_sf,
|
||||
torch::Tensor const& alpha, int64_t m, int64_t n,
|
||||
int64_t k, cudaStream_t stream) {
|
||||
TORCH_CHECK(false,
|
||||
"Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
|
||||
"a CUTLASS 3.8 source directory to enable support.");
|
||||
void cutlass_fp4_gemm_dispatch(torch::stable::Tensor& D,
|
||||
torch::stable::Tensor const& A,
|
||||
torch::stable::Tensor const& B,
|
||||
torch::stable::Tensor const& A_sf,
|
||||
torch::stable::Tensor const& B_sf,
|
||||
torch::stable::Tensor const& alpha, int64_t m,
|
||||
int64_t n, int64_t k, cudaStream_t stream) {
|
||||
STD_TORCH_CHECK(false,
|
||||
"Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
|
||||
"a CUTLASS 3.8 source directory to enable support.");
|
||||
}
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
#define CHECK_TYPE(x, st, m) \
|
||||
TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m)
|
||||
#define CHECK_TYPE(x, st, m) \
|
||||
STD_TORCH_CHECK(x.scalar_type() == st, \
|
||||
": Inconsistency of torch::stable::Tensor type:", m)
|
||||
#define CHECK_TH_CUDA(x, m) \
|
||||
TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor")
|
||||
STD_TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x, m) \
|
||||
TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous")
|
||||
STD_TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous")
|
||||
#define CHECK_INPUT(x, st, m) \
|
||||
CHECK_TH_CUDA(x, m); \
|
||||
CHECK_CONTIGUOUS(x, m); \
|
||||
CHECK_TYPE(x, st, m)
|
||||
|
||||
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
|
||||
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
|
||||
constexpr auto FLOAT4_E2M1X2 = torch::headeronly::ScalarType::Byte;
|
||||
constexpr auto SF_DTYPE = torch::headeronly::ScalarType::Float8_e4m3fn;
|
||||
|
||||
void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
|
||||
torch::Tensor const& B,
|
||||
torch::Tensor const& A_sf,
|
||||
torch::Tensor const& B_sf,
|
||||
torch::Tensor const& alpha) {
|
||||
void cutlass_scaled_fp4_mm_sm100a(torch::stable::Tensor& D,
|
||||
torch::stable::Tensor const& A,
|
||||
torch::stable::Tensor const& B,
|
||||
torch::stable::Tensor const& A_sf,
|
||||
torch::stable::Tensor const& B_sf,
|
||||
torch::stable::Tensor const& alpha) {
|
||||
CHECK_INPUT(A, FLOAT4_E2M1X2, "a");
|
||||
CHECK_INPUT(B, FLOAT4_E2M1X2, "b");
|
||||
|
||||
CHECK_INPUT(A_sf, SF_DTYPE, "scale_a");
|
||||
CHECK_INPUT(B_sf, SF_DTYPE, "scale_b");
|
||||
|
||||
CHECK_INPUT(alpha, at::ScalarType::Float, "alpha");
|
||||
CHECK_INPUT(alpha, torch::headeronly::ScalarType::Float, "alpha");
|
||||
|
||||
TORCH_CHECK(A.dim() == 2, "a must be a matrix");
|
||||
TORCH_CHECK(B.dim() == 2, "b must be a matrix");
|
||||
TORCH_CHECK(A.sizes()[1] == B.sizes()[1],
|
||||
"a and b shapes cannot be multiplied (", A.sizes()[0], "x",
|
||||
A.sizes()[1], " and ", B.sizes()[0], "x", B.sizes()[1], ")");
|
||||
STD_TORCH_CHECK(A.dim() == 2, "a must be a matrix");
|
||||
STD_TORCH_CHECK(B.dim() == 2, "b must be a matrix");
|
||||
STD_TORCH_CHECK(A.size(1) == B.size(1),
|
||||
"a and b shapes cannot be multiplied (", A.size(0), "x",
|
||||
A.size(1), " and ", B.size(0), "x", B.size(1), ")");
|
||||
|
||||
auto const m = A.sizes()[0];
|
||||
auto const n = B.sizes()[0];
|
||||
auto const k = A.sizes()[1] * 2;
|
||||
auto const m = A.size(0);
|
||||
auto const n = B.size(0);
|
||||
auto const k = A.size(1) * 2;
|
||||
|
||||
constexpr int alignment = 32;
|
||||
TORCH_CHECK(k % alignment == 0, "Expected k to be divisible by ", alignment,
|
||||
", but got a shape: (", A.sizes()[0], "x", A.sizes()[1],
|
||||
"), k: ", k, ".");
|
||||
TORCH_CHECK(n % alignment == 0, "Expected n to be divisible by ", alignment,
|
||||
", but got b shape: (", B.sizes()[0], "x", B.sizes()[1], ").");
|
||||
STD_TORCH_CHECK(k % alignment == 0, "Expected k to be divisible by ",
|
||||
alignment, ", but got a shape: (", A.size(0), "x", A.size(1),
|
||||
"), k: ", k, ".");
|
||||
STD_TORCH_CHECK(n % alignment == 0, "Expected n to be divisible by ",
|
||||
alignment, ", but got b shape: (", B.size(0), "x", B.size(1),
|
||||
").");
|
||||
|
||||
auto round_up = [](int x, int y) { return (x + y - 1) / y * y; };
|
||||
int rounded_m = round_up(m, 128);
|
||||
@@ -285,33 +291,34 @@ void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
|
||||
// integer.
|
||||
int rounded_k = round_up(k / 16, 4);
|
||||
|
||||
TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix");
|
||||
TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix");
|
||||
TORCH_CHECK(A_sf.sizes()[1] == B_sf.sizes()[1],
|
||||
"scale_a and scale_b shapes cannot be multiplied (",
|
||||
A_sf.sizes()[0], "x", A_sf.sizes()[1], " and ", B_sf.sizes()[0],
|
||||
"x", B_sf.sizes()[1], ")");
|
||||
TORCH_CHECK(A_sf.sizes()[0] == rounded_m && A_sf.sizes()[1] == rounded_k,
|
||||
"scale_a must be padded and swizzled to a shape (", rounded_m,
|
||||
"x", rounded_k, "), but got a shape (", A_sf.sizes()[0], "x",
|
||||
A_sf.sizes()[1], ")");
|
||||
TORCH_CHECK(B_sf.sizes()[0] == rounded_n && B_sf.sizes()[1] == rounded_k,
|
||||
"scale_b must be padded and swizzled to a shape (", rounded_n,
|
||||
"x", rounded_k, "), but got a shape (", B_sf.sizes()[0], "x",
|
||||
B_sf.sizes()[1], ")");
|
||||
STD_TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix");
|
||||
STD_TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix");
|
||||
STD_TORCH_CHECK(A_sf.size(1) == B_sf.size(1),
|
||||
"scale_a and scale_b shapes cannot be multiplied (",
|
||||
A_sf.size(0), "x", A_sf.size(1), " and ", B_sf.size(0), "x",
|
||||
B_sf.size(1), ")");
|
||||
STD_TORCH_CHECK(A_sf.size(0) == rounded_m && A_sf.size(1) == rounded_k,
|
||||
"scale_a must be padded and swizzled to a shape (", rounded_m,
|
||||
"x", rounded_k, "), but got a shape (", A_sf.size(0), "x",
|
||||
A_sf.size(1), ")");
|
||||
STD_TORCH_CHECK(B_sf.size(0) == rounded_n && B_sf.size(1) == rounded_k,
|
||||
"scale_b must be padded and swizzled to a shape (", rounded_n,
|
||||
"x", rounded_k, "), but got a shape (", B_sf.size(0), "x",
|
||||
B_sf.size(1), ")");
|
||||
|
||||
auto out_dtype = D.dtype();
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device());
|
||||
auto out_dtype = D.scalar_type();
|
||||
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||
A.get_device_index());
|
||||
const cudaStream_t stream = get_current_cuda_stream(A.get_device_index());
|
||||
|
||||
if (out_dtype == at::ScalarType::Half) {
|
||||
if (out_dtype == torch::headeronly::ScalarType::Half) {
|
||||
cutlass_fp4_gemm_dispatch<cutlass::half_t>(D, A, B, A_sf, B_sf, alpha, m, n,
|
||||
k, stream);
|
||||
} else if (out_dtype == at::ScalarType::BFloat16) {
|
||||
} else if (out_dtype == torch::headeronly::ScalarType::BFloat16) {
|
||||
cutlass_fp4_gemm_dispatch<cutlass::bfloat16_t>(D, A, B, A_sf, B_sf, alpha,
|
||||
m, n, k, stream);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm (", out_dtype,
|
||||
")");
|
||||
STD_TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm (",
|
||||
out_dtype, ")");
|
||||
}
|
||||
}
|
||||
@@ -14,10 +14,9 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <torch/all.h>
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
|
||||
@@ -34,19 +33,20 @@
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#define CHECK_TYPE(x, st, m) \
|
||||
TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m)
|
||||
#define CHECK_TYPE(x, st, m) \
|
||||
STD_TORCH_CHECK(x.scalar_type() == st, \
|
||||
": Inconsistency of torch::stable::Tensor type:", m)
|
||||
#define CHECK_TH_CUDA(x, m) \
|
||||
TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor")
|
||||
STD_TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x, m) \
|
||||
TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous")
|
||||
STD_TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous")
|
||||
#define CHECK_INPUT(x, st, m) \
|
||||
CHECK_TH_CUDA(x, m); \
|
||||
CHECK_CONTIGUOUS(x, m); \
|
||||
CHECK_TYPE(x, st, m)
|
||||
|
||||
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
|
||||
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
|
||||
constexpr auto FLOAT4_E2M1X2 = torch::headeronly::ScalarType::Byte;
|
||||
constexpr auto SF_DTYPE = torch::headeronly::ScalarType::Float8_e4m3fn;
|
||||
|
||||
struct sm120_fp4_config_M256 {
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
@@ -109,12 +109,13 @@ struct Fp4GemmSm120 {
|
||||
};
|
||||
|
||||
template <typename Gemm>
|
||||
typename Gemm::Arguments args_from_options(at::Tensor& D, at::Tensor const& A,
|
||||
at::Tensor const& B,
|
||||
at::Tensor const& A_sf,
|
||||
at::Tensor const& B_sf,
|
||||
torch::Tensor const& alpha, int M,
|
||||
int N, int K) {
|
||||
typename Gemm::Arguments args_from_options(torch::stable::Tensor& D,
|
||||
torch::stable::Tensor const& A,
|
||||
torch::stable::Tensor const& B,
|
||||
torch::stable::Tensor const& A_sf,
|
||||
torch::stable::Tensor const& B_sf,
|
||||
torch::stable::Tensor const& alpha,
|
||||
int M, int N, int K) {
|
||||
using ElementA = typename Gemm::ElementA;
|
||||
using ElementB = typename Gemm::ElementB;
|
||||
using ElementD = typename Gemm::ElementD;
|
||||
@@ -158,18 +159,19 @@ typename Gemm::Arguments args_from_options(at::Tensor& D, at::Tensor const& A,
|
||||
}
|
||||
|
||||
template <typename Gemm>
|
||||
void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
|
||||
at::Tensor const& A_sf, at::Tensor const& B_sf,
|
||||
torch::Tensor const& alpha, int M, int N, int K,
|
||||
void runGemm(torch::stable::Tensor& D, torch::stable::Tensor const& A,
|
||||
torch::stable::Tensor const& B, torch::stable::Tensor const& A_sf,
|
||||
torch::stable::Tensor const& B_sf,
|
||||
torch::stable::Tensor const& alpha, int M, int N, int K,
|
||||
cudaStream_t stream) {
|
||||
Gemm gemm;
|
||||
|
||||
auto arguments = args_from_options<Gemm>(D, A, B, A_sf, B_sf, alpha, M, N, K);
|
||||
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
auto const workspace_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
auto workspace =
|
||||
torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
|
||||
std::nullopt, A.device());
|
||||
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
@@ -178,12 +180,13 @@ void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
|
||||
CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream));
|
||||
}
|
||||
|
||||
void cutlass_fp4_bf16_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
|
||||
torch::Tensor const& B,
|
||||
torch::Tensor const& A_sf,
|
||||
torch::Tensor const& B_sf,
|
||||
torch::Tensor const& alpha, int m, int n,
|
||||
int k, cudaStream_t stream) {
|
||||
void cutlass_fp4_bf16_gemm_dispatch(torch::stable::Tensor& D,
|
||||
torch::stable::Tensor const& A,
|
||||
torch::stable::Tensor const& B,
|
||||
torch::stable::Tensor const& A_sf,
|
||||
torch::stable::Tensor const& B_sf,
|
||||
torch::stable::Tensor const& alpha, int m,
|
||||
int n, int k, cudaStream_t stream) {
|
||||
uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
|
||||
if (mp2 <= 256) {
|
||||
runGemm<Fp4GemmSm120<sm120_fp4_config_M256, cutlass::bfloat16_t>::Gemm>(
|
||||
@@ -194,12 +197,13 @@ void cutlass_fp4_bf16_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
|
||||
}
|
||||
}
|
||||
|
||||
void cutlass_fp4_f16_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
|
||||
torch::Tensor const& B,
|
||||
torch::Tensor const& A_sf,
|
||||
torch::Tensor const& B_sf,
|
||||
torch::Tensor const& alpha, int m, int n,
|
||||
int k, cudaStream_t stream) {
|
||||
void cutlass_fp4_f16_gemm_dispatch(torch::stable::Tensor& D,
|
||||
torch::stable::Tensor const& A,
|
||||
torch::stable::Tensor const& B,
|
||||
torch::stable::Tensor const& A_sf,
|
||||
torch::stable::Tensor const& B_sf,
|
||||
torch::stable::Tensor const& alpha, int m,
|
||||
int n, int k, cudaStream_t stream) {
|
||||
uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
|
||||
if (mp2 <= 256) {
|
||||
runGemm<Fp4GemmSm120<sm120_fp4_config_M256, cutlass::half_t>::Gemm>(
|
||||
@@ -210,11 +214,12 @@ void cutlass_fp4_f16_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
|
||||
}
|
||||
}
|
||||
|
||||
void cutlass_scaled_fp4_mm_sm120a(torch::Tensor& D, torch::Tensor const& A,
|
||||
torch::Tensor const& B,
|
||||
torch::Tensor const& A_sf,
|
||||
torch::Tensor const& B_sf,
|
||||
torch::Tensor const& alpha) {
|
||||
void cutlass_scaled_fp4_mm_sm120a(torch::stable::Tensor& D,
|
||||
torch::stable::Tensor const& A,
|
||||
torch::stable::Tensor const& B,
|
||||
torch::stable::Tensor const& A_sf,
|
||||
torch::stable::Tensor const& B_sf,
|
||||
torch::stable::Tensor const& alpha) {
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
CHECK_INPUT(A, FLOAT4_E2M1X2, "a");
|
||||
CHECK_INPUT(B, FLOAT4_E2M1X2, "b");
|
||||
@@ -222,24 +227,25 @@ void cutlass_scaled_fp4_mm_sm120a(torch::Tensor& D, torch::Tensor const& A,
|
||||
CHECK_INPUT(A_sf, SF_DTYPE, "scale_a");
|
||||
CHECK_INPUT(B_sf, SF_DTYPE, "scale_b");
|
||||
|
||||
CHECK_INPUT(alpha, at::ScalarType::Float, "alpha");
|
||||
CHECK_INPUT(alpha, torch::headeronly::ScalarType::Float, "alpha");
|
||||
|
||||
TORCH_CHECK(A.dim() == 2, "a must be a matrix");
|
||||
TORCH_CHECK(B.dim() == 2, "b must be a matrix");
|
||||
TORCH_CHECK(A.sizes()[1] == B.sizes()[1],
|
||||
"a and b shapes cannot be multiplied (", A.sizes()[0], "x",
|
||||
A.sizes()[1], " and ", B.sizes()[0], "x", B.sizes()[1], ")");
|
||||
STD_TORCH_CHECK(A.dim() == 2, "a must be a matrix");
|
||||
STD_TORCH_CHECK(B.dim() == 2, "b must be a matrix");
|
||||
STD_TORCH_CHECK(A.size(1) == B.size(1),
|
||||
"a and b shapes cannot be multiplied (", A.size(0), "x",
|
||||
A.size(1), " and ", B.size(0), "x", B.size(1), ")");
|
||||
|
||||
auto const m = A.sizes()[0];
|
||||
auto const n = B.sizes()[0];
|
||||
auto const k = A.sizes()[1] * 2;
|
||||
auto const m = A.size(0);
|
||||
auto const n = B.size(0);
|
||||
auto const k = A.size(1) * 2;
|
||||
|
||||
constexpr int alignment = 32;
|
||||
TORCH_CHECK(k % alignment == 0, "Expected k to be divisible by ", alignment,
|
||||
", but got a shape: (", A.sizes()[0], "x", A.sizes()[1],
|
||||
"), k: ", k, ".");
|
||||
TORCH_CHECK(n % alignment == 0, "Expected n to be divisible by ", alignment,
|
||||
", but got b shape: (", B.sizes()[0], "x", B.sizes()[1], ").");
|
||||
STD_TORCH_CHECK(k % alignment == 0, "Expected k to be divisible by ",
|
||||
alignment, ", but got a shape: (", A.size(0), "x", A.size(1),
|
||||
"), k: ", k, ".");
|
||||
STD_TORCH_CHECK(n % alignment == 0, "Expected n to be divisible by ",
|
||||
alignment, ", but got b shape: (", B.size(0), "x", B.size(1),
|
||||
").");
|
||||
|
||||
auto round_up = [](int x, int y) { return (x + y - 1) / y * y; };
|
||||
int rounded_m = round_up(m, 128);
|
||||
@@ -248,38 +254,39 @@ void cutlass_scaled_fp4_mm_sm120a(torch::Tensor& D, torch::Tensor const& A,
|
||||
// integer.
|
||||
int rounded_k = round_up(k / 16, 4);
|
||||
|
||||
TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix");
|
||||
TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix");
|
||||
TORCH_CHECK(A_sf.sizes()[1] == B_sf.sizes()[1],
|
||||
"scale_a and scale_b shapes cannot be multiplied (",
|
||||
A_sf.sizes()[0], "x", A_sf.sizes()[1], " and ", B_sf.sizes()[0],
|
||||
"x", B_sf.sizes()[1], ")");
|
||||
TORCH_CHECK(A_sf.sizes()[0] == rounded_m && A_sf.sizes()[1] == rounded_k,
|
||||
"scale_a must be padded and swizzled to a shape (", rounded_m,
|
||||
"x", rounded_k, "), but got a shape (", A_sf.sizes()[0], "x",
|
||||
A_sf.sizes()[1], ")");
|
||||
TORCH_CHECK(B_sf.sizes()[0] == rounded_n && B_sf.sizes()[1] == rounded_k,
|
||||
"scale_b must be padded and swizzled to a shape (", rounded_n,
|
||||
"x", rounded_k, "), but got a shape (", B_sf.sizes()[0], "x",
|
||||
B_sf.sizes()[1], ")");
|
||||
STD_TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix");
|
||||
STD_TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix");
|
||||
STD_TORCH_CHECK(A_sf.size(1) == B_sf.size(1),
|
||||
"scale_a and scale_b shapes cannot be multiplied (",
|
||||
A_sf.size(0), "x", A_sf.size(1), " and ", B_sf.size(0), "x",
|
||||
B_sf.size(1), ")");
|
||||
STD_TORCH_CHECK(A_sf.size(0) == rounded_m && A_sf.size(1) == rounded_k,
|
||||
"scale_a must be padded and swizzled to a shape (", rounded_m,
|
||||
"x", rounded_k, "), but got a shape (", A_sf.size(0), "x",
|
||||
A_sf.size(1), ")");
|
||||
STD_TORCH_CHECK(B_sf.size(0) == rounded_n && B_sf.size(1) == rounded_k,
|
||||
"scale_b must be padded and swizzled to a shape (", rounded_n,
|
||||
"x", rounded_k, "), but got a shape (", B_sf.size(0), "x",
|
||||
B_sf.size(1), ")");
|
||||
|
||||
auto out_dtype = D.dtype();
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device());
|
||||
auto out_dtype = D.scalar_type();
|
||||
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||
A.get_device_index());
|
||||
const cudaStream_t stream = get_current_cuda_stream(A.get_device_index());
|
||||
|
||||
if (out_dtype == at::ScalarType::BFloat16) {
|
||||
if (out_dtype == torch::headeronly::ScalarType::BFloat16) {
|
||||
return cutlass_fp4_bf16_gemm_dispatch(D, A, B, A_sf, B_sf, alpha, m, n, k,
|
||||
stream);
|
||||
} else if (out_dtype == at::ScalarType::Half) {
|
||||
} else if (out_dtype == torch::headeronly::ScalarType::Half) {
|
||||
return cutlass_fp4_f16_gemm_dispatch(D, A, B, A_sf, B_sf, alpha, m, n, k,
|
||||
stream);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm sm120 (",
|
||||
out_dtype, ")");
|
||||
STD_TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm sm120 (",
|
||||
out_dtype, ")");
|
||||
}
|
||||
#else
|
||||
TORCH_CHECK(false,
|
||||
"Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
|
||||
"a CUTLASS 3.8 source directory to enable support.");
|
||||
STD_TORCH_CHECK(false,
|
||||
"Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
|
||||
"a CUTLASS 3.8 source directory to enable support.");
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
}
|
||||
}
|
||||
@@ -20,7 +20,7 @@
|
||||
#include <cuda_fp8.h>
|
||||
#include <utility>
|
||||
|
||||
#include "../../cuda_vec_utils.cuh"
|
||||
#include "cuda_vec_utils.cuh"
|
||||
|
||||
#if defined(NVFP4_ENABLE_ELTS16) && defined(CUDA_VERSION) && \
|
||||
CUDA_VERSION >= 12090
|
||||
@@ -26,8 +26,10 @@ using namespace cute;
|
||||
template <class OutType, int ScaleGranularityM,
|
||||
int ScaleGranularityN, int ScaleGranularityK,
|
||||
class MmaTileShape, class ClusterShape,
|
||||
class EpilogueScheduler, class MainloopScheduler>
|
||||
class EpilogueScheduler, class MainloopScheduler,
|
||||
bool swap_ab_ = false>
|
||||
struct cutlass_3x_gemm_fp8_blockwise {
|
||||
static constexpr bool swap_ab = swap_ab_;
|
||||
using ElementAB = cutlass::float_e4m3_t;
|
||||
|
||||
using ElementA = ElementAB;
|
||||
@@ -55,9 +57,13 @@ struct cutlass_3x_gemm_fp8_blockwise {
|
||||
using ElementCompute = float;
|
||||
using ElementBlockScale = float;
|
||||
|
||||
using ScaleConfig = cutlass::detail::Sm120BlockwiseScaleConfig<
|
||||
using ScaleConfig = conditional_t<swap_ab,
|
||||
cutlass::detail::Sm120BlockwiseScaleConfig<
|
||||
ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
|
||||
cute::UMMA::Major::MN, cute::UMMA::Major::K>;
|
||||
cute::UMMA::Major::K, cute::UMMA::Major::MN>,
|
||||
cutlass::detail::Sm120BlockwiseScaleConfig<
|
||||
ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
|
||||
cute::UMMA::Major::MN, cute::UMMA::Major::K>>;
|
||||
|
||||
// layout_SFA and layout_SFB cannot be swapped since they are deduced.
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
|
||||
@@ -78,17 +84,32 @@ struct cutlass_3x_gemm_fp8_blockwise {
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
conditional_t<swap_ab, LayoutC_Transpose, LayoutC>,
|
||||
AlignmentC,
|
||||
ElementD,
|
||||
LayoutD,
|
||||
conditional_t<swap_ab, LayoutD_Transpose, LayoutD>,
|
||||
AlignmentD,
|
||||
EpilogueScheduler,
|
||||
DefaultOperation
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCountType = cutlass::gemm::collective::StageCountAuto;
|
||||
using CollectiveMainloop =
|
||||
using CollectiveMainloop = conditional_t<swap_ab,
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OperatorClass,
|
||||
ElementB,
|
||||
cute::tuple<LayoutB_Transpose, LayoutSFA>,
|
||||
AlignmentB,
|
||||
ElementA,
|
||||
cute::tuple<LayoutA_Transpose, LayoutSFB>,
|
||||
AlignmentA,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
MainloopScheduler
|
||||
>::CollectiveOp,
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OperatorClass,
|
||||
@@ -103,7 +124,7 @@ struct cutlass_3x_gemm_fp8_blockwise {
|
||||
ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
MainloopScheduler
|
||||
>::CollectiveOp;
|
||||
>::CollectiveOp>;
|
||||
|
||||
// SM12x family to support both SM120 (RTX 5090) and SM121 (DGX Spark)
|
||||
using KernelType = enable_sm120_family<cutlass::gemm::kernel::GemmUniversal<
|
||||
@@ -115,7 +136,7 @@ struct cutlass_3x_gemm_fp8_blockwise {
|
||||
// Tile configurations for different M ranges
|
||||
template <typename OutType>
|
||||
struct sm120_blockwise_fp8_config_default {
|
||||
// M > 256: use 128x128x128 tile with Cooperative (Auto) schedule
|
||||
// use 128x128x128 tile with Cooperative (Auto) schedule
|
||||
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
|
||||
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
||||
using TileShape = Shape<_128, _128, _128>;
|
||||
@@ -127,8 +148,8 @@ struct sm120_blockwise_fp8_config_default {
|
||||
};
|
||||
|
||||
template <typename OutType>
|
||||
struct sm120_blockwise_fp8_config_M64 {
|
||||
// M in [1, 256]: use 64x128x128 tile with Pingpong schedule
|
||||
struct sm120_blockwise_fp8_config_pingpong {
|
||||
// use 64x128x128 tile with Pingpong schedule
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedBlockwisePingpongSm120;
|
||||
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
||||
using TileShape = Shape<_64, _128, _128>;
|
||||
@@ -139,11 +160,24 @@ struct sm120_blockwise_fp8_config_M64 {
|
||||
EpilogueSchedule, KernelSchedule>;
|
||||
};
|
||||
|
||||
template <typename OutType>
|
||||
struct sm120_blockwise_fp8_config_swapab {
|
||||
// use 128x32x128 tile with Cooperative schedule
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedBlockwiseCooperativeSm120;
|
||||
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
||||
using TileShape = Shape<_128, _32, _128>;
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
using Gemm = cutlass_3x_gemm_fp8_blockwise<
|
||||
OutType, 128, 1, 128, TileShape, ClusterShape,
|
||||
EpilogueSchedule, KernelSchedule, true>;
|
||||
};
|
||||
|
||||
template <typename Gemm>
|
||||
void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales) {
|
||||
static constexpr bool swap_ab = Gemm::swap_ab;
|
||||
using GemmKernel = typename Gemm::GemmKernel;
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
@@ -167,11 +201,13 @@ void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Te
|
||||
b_stride =
|
||||
cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
|
||||
c_stride =
|
||||
cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1));
|
||||
cutlass::make_cute_packed_stride(StrideC{}, swap_ab ? cute::make_shape(n, m, 1) : cute::make_shape(m, n, 1));
|
||||
|
||||
LayoutSFA layout_SFA =
|
||||
LayoutSFA layout_SFA = swap_ab ?
|
||||
ScaleConfig::tile_atom_to_shape_SFA(make_shape(n, m, k, 1)) :
|
||||
ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
|
||||
LayoutSFB layout_SFB =
|
||||
LayoutSFB layout_SFB = swap_ab ?
|
||||
ScaleConfig::tile_atom_to_shape_SFB(make_shape(n, m, k, 1)) :
|
||||
ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));
|
||||
|
||||
auto a_ptr = static_cast<ElementAB const*>(a.data_ptr());
|
||||
@@ -180,15 +216,24 @@ void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Te
|
||||
auto b_scales_ptr = static_cast<ElementBlockScale const*>(b_scales.data_ptr());
|
||||
|
||||
typename GemmKernel::MainloopArguments mainloop_args{};
|
||||
mainloop_args.ptr_A = a_ptr;
|
||||
mainloop_args.dA = a_stride;
|
||||
mainloop_args.ptr_B = b_ptr;
|
||||
mainloop_args.dB = b_stride;
|
||||
mainloop_args.ptr_SFA = a_scales_ptr;
|
||||
mainloop_args.layout_SFA = layout_SFA;
|
||||
mainloop_args.ptr_SFB = b_scales_ptr;
|
||||
mainloop_args.layout_SFB = layout_SFB;
|
||||
auto prob_shape = cute::make_shape(m, n, k, 1);
|
||||
if (swap_ab) {
|
||||
mainloop_args.ptr_A = b_ptr;
|
||||
mainloop_args.dA = b_stride;
|
||||
mainloop_args.ptr_B = a_ptr;
|
||||
mainloop_args.dB = a_stride;
|
||||
mainloop_args.ptr_SFA = b_scales_ptr;
|
||||
mainloop_args.ptr_SFB = a_scales_ptr;
|
||||
} else {
|
||||
mainloop_args.ptr_A = a_ptr;
|
||||
mainloop_args.dA = a_stride;
|
||||
mainloop_args.ptr_B = b_ptr;
|
||||
mainloop_args.dB = b_stride;
|
||||
mainloop_args.ptr_SFA = a_scales_ptr;
|
||||
mainloop_args.ptr_SFB = b_scales_ptr;
|
||||
}
|
||||
auto prob_shape = swap_ab ? cute::make_shape(n, m, k, 1) : cute::make_shape(m, n, k, 1);
|
||||
|
||||
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
|
||||
typename GemmKernel::EpilogueArguments epilogue_args{
|
||||
@@ -204,15 +249,26 @@ void cutlass_gemm_blockwise_sm120_fp8_dispatch(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales) {
|
||||
int M = a.size(0);
|
||||
if (M <= 256) {
|
||||
using Gemm = typename sm120_blockwise_fp8_config_M64<OutType>::Gemm;
|
||||
// more heuristic tuning can be done here by checking N/K dimensions as well
|
||||
bool swap_ab = (M <= 64) || (M % 4 != 0);
|
||||
|
||||
if (!swap_ab) {
|
||||
if (M <= 256) {
|
||||
using Gemm = typename sm120_blockwise_fp8_config_pingpong<OutType>::Gemm;
|
||||
return cutlass_gemm_caller_blockwise<Gemm>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
// M > 256: use default 128x128x128 config with Cooperative (Auto) schedule
|
||||
using Gemm = typename sm120_blockwise_fp8_config_default<OutType>::Gemm;
|
||||
return cutlass_gemm_caller_blockwise<Gemm>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
} else {
|
||||
// Swap A/B for small M to improve performance
|
||||
// Use TILE_N=32 as the minimum compatible tile size.
|
||||
using Gemm = typename sm120_blockwise_fp8_config_swapab<OutType>::Gemm;
|
||||
return cutlass_gemm_caller_blockwise<Gemm>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
// M > 256: use default 128x128x128 config with Cooperative (Auto) schedule
|
||||
using Gemm = typename sm120_blockwise_fp8_config_default<OutType>::Gemm;
|
||||
return cutlass_gemm_caller_blockwise<Gemm>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
@@ -103,6 +103,102 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
|
||||
ops.def(
|
||||
"cutlass_scaled_mm_supports_block_fp8(int cuda_device_capability) -> "
|
||||
"bool");
|
||||
|
||||
// CUTLASS nvfp4 block scaled GEMM
|
||||
ops.def(
|
||||
"cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
|
||||
" Tensor block_scale_a, Tensor block_scale_b,"
|
||||
" Tensor alpha) -> ()");
|
||||
|
||||
// cutlass nvfp4 block scaled group GEMM
|
||||
ops.def(
|
||||
"cutlass_fp4_group_mm(Tensor! out, Tensor a, Tensor b,"
|
||||
" Tensor a_blockscale, Tensor b_blockscales, Tensor alphas,"
|
||||
" Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()");
|
||||
|
||||
// Compute NVFP4 block quantized tensor.
|
||||
ops.def(
|
||||
"scaled_fp4_quant(Tensor input,"
|
||||
" Tensor input_scale, bool "
|
||||
"is_sf_swizzled_layout) -> (Tensor, Tensor)");
|
||||
|
||||
// Out variant
|
||||
// TODO: Add out_variant tag once PyTorch supports it (added in 2.11)
|
||||
// This registration is now migrated to stable ABI
|
||||
// at::Tag::out_variant is not available in the stable ABI (enum_tag.h is not
|
||||
// yet in torch/headeronly), the tag should be applied from Python
|
||||
// via torch.library.Library.define(..., tags=(torch.Tag.out_variant,))
|
||||
// with the .impl remaining in C++.
|
||||
// See pytorch/pytorch#176117.
|
||||
ops.def(
|
||||
"scaled_fp4_quant.out(Tensor input,"
|
||||
" Tensor input_scale, bool "
|
||||
"is_sf_swizzled_layout, *, Tensor(a!) output, Tensor(b!) output_scale) "
|
||||
"-> ()");
|
||||
|
||||
// Compute NVFP4 experts quantization.
|
||||
ops.def(
|
||||
"scaled_fp4_experts_quant(Tensor! output, Tensor! output_scale,"
|
||||
"Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
|
||||
"Tensor output_scale_offset_by_experts) -> ()");
|
||||
|
||||
// Fused SiLU+Mul+NVFP4 experts quantization.
|
||||
ops.def(
|
||||
"silu_and_mul_scaled_fp4_experts_quant(Tensor! output, Tensor! "
|
||||
"output_scale,"
|
||||
"Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
|
||||
"Tensor output_scale_offset_by_experts) -> ()");
|
||||
|
||||
// Fused SiLU+Mul+NVFP4 quantization.
|
||||
ops.def(
|
||||
"silu_and_mul_nvfp4_quant(Tensor! result, Tensor! result_block_scale, "
|
||||
"Tensor input, Tensor input_global_scale) -> ()");
|
||||
|
||||
// Check if cutlass_scaled_mm_fp4 is supported for CUDA devices
|
||||
// of the given capability
|
||||
ops.def("cutlass_scaled_mm_supports_fp4(int cuda_device_capability) -> bool");
|
||||
|
||||
// CUTLASS w4a8 GEMM
|
||||
ops.def(
|
||||
"cutlass_w4a8_mm("
|
||||
" Tensor A,"
|
||||
" Tensor B,"
|
||||
" Tensor group_scales,"
|
||||
" int group_size,"
|
||||
" Tensor channel_scales,"
|
||||
" Tensor token_scales,"
|
||||
" ScalarType? out_type,"
|
||||
" str? maybe_schedule"
|
||||
") -> Tensor");
|
||||
|
||||
// pack scales
|
||||
ops.def("cutlass_pack_scale_fp8(Tensor scales) -> Tensor");
|
||||
|
||||
// encode and reorder weight matrix
|
||||
ops.def("cutlass_encode_and_reorder_int4b(Tensor B) -> Tensor");
|
||||
|
||||
// CUTLASS w4a8 grouped GEMM
|
||||
ops.def(
|
||||
"cutlass_w4a8_moe_mm("
|
||||
" Tensor! out_tensors,"
|
||||
" Tensor a_tensors,"
|
||||
" Tensor b_tensors,"
|
||||
" Tensor a_scales,"
|
||||
" Tensor b_scales,"
|
||||
" Tensor b_group_scales,"
|
||||
" int b_group_size,"
|
||||
" Tensor expert_offsets,"
|
||||
" Tensor problem_sizes,"
|
||||
" Tensor a_strides,"
|
||||
" Tensor b_strides,"
|
||||
" Tensor c_strides,"
|
||||
" Tensor group_scale_strides,"
|
||||
" str? maybe_schedule"
|
||||
") -> ()");
|
||||
|
||||
ops.def(
|
||||
"cutlass_encode_and_reorder_int4b_grouped(Tensor b_tensors) -> (Tensor, "
|
||||
"Tensor)");
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -128,6 +224,18 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
|
||||
TORCH_BOX(&get_cutlass_moe_mm_problem_sizes_from_expert_offsets));
|
||||
ops.impl("get_cutlass_batched_moe_mm_data",
|
||||
TORCH_BOX(&get_cutlass_batched_moe_mm_data));
|
||||
|
||||
// FP4/NVFP4 ops
|
||||
ops.impl("cutlass_scaled_fp4_mm", TORCH_BOX(&cutlass_scaled_fp4_mm));
|
||||
ops.impl("scaled_fp4_quant", TORCH_BOX(&scaled_fp4_quant_func));
|
||||
ops.impl("scaled_fp4_quant.out", TORCH_BOX(&scaled_fp4_quant_out));
|
||||
ops.impl("scaled_fp4_experts_quant", TORCH_BOX(&scaled_fp4_experts_quant));
|
||||
ops.impl("silu_and_mul_scaled_fp4_experts_quant",
|
||||
TORCH_BOX(&silu_and_mul_scaled_fp4_experts_quant));
|
||||
ops.impl("silu_and_mul_nvfp4_quant", TORCH_BOX(&silu_and_mul_nvfp4_quant));
|
||||
|
||||
// W4A8 ops: impl registrations are in the source files
|
||||
// (w4a8_mm_entry.cu and w4a8_grouped_mm_entry.cu)
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -143,6 +251,8 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CompositeExplicitAutograd, ops) {
|
||||
TORCH_BOX(&cutlass_group_gemm_supported));
|
||||
ops.impl("cutlass_scaled_mm_supports_block_fp8",
|
||||
TORCH_BOX(&cutlass_scaled_mm_supports_block_fp8));
|
||||
ops.impl("cutlass_scaled_mm_supports_fp4",
|
||||
TORCH_BOX(&cutlass_scaled_mm_supports_fp4));
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
|
||||
#include <torch/csrc/stable/accelerator.h>
|
||||
#include <torch/csrc/stable/ops.h>
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include <torch/headeronly/util/shim_utils.h>
|
||||
|
||||
|
||||
@@ -1,144 +0,0 @@
|
||||
/*
|
||||
* Adapted from
|
||||
* https://github.com/NVIDIA/TensorRT-LLM/blob/v1.3.0rc7/cpp/tensorrt_llm/kernels/tinygemm2/tinygemm2_cuda.cu
|
||||
* Copyright (c) 2025, The vLLM team.
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
|
||||
* All rights reserved. SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <torch/all.h>
|
||||
#include "gpt_oss_router_gemm.cuh"
|
||||
|
||||
void launch_gpt_oss_router_gemm(__nv_bfloat16* gA, __nv_bfloat16* gB,
|
||||
__nv_bfloat16* gC, __nv_bfloat16* bias,
|
||||
int batch_size, int output_features,
|
||||
int input_features, cudaStream_t stream) {
|
||||
static int const WARP_TILE_M = 16;
|
||||
static int const TILE_M = WARP_TILE_M;
|
||||
static int const TILE_N = 8;
|
||||
static int const TILE_K = 64;
|
||||
static int const STAGES = 16;
|
||||
static int const STAGE_UNROLL = 4;
|
||||
static bool const PROFILE = false;
|
||||
|
||||
CUtensorMap weight_map{};
|
||||
CUtensorMap activation_map{};
|
||||
|
||||
constexpr uint32_t rank = 2;
|
||||
uint64_t size[rank] = {(uint64_t)input_features, (uint64_t)output_features};
|
||||
uint64_t stride[rank - 1] = {input_features * sizeof(__nv_bfloat16)};
|
||||
uint32_t box_size[rank] = {TILE_K, TILE_M};
|
||||
uint32_t elem_stride[rank] = {1, 1};
|
||||
|
||||
CUresult res = cuTensorMapEncodeTiled(
|
||||
&weight_map, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, rank,
|
||||
gB, size, stride, box_size, elem_stride,
|
||||
CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
|
||||
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B,
|
||||
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE,
|
||||
CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
|
||||
TORCH_CHECK(res == CUDA_SUCCESS,
|
||||
"cuTensorMapEncodeTiled failed for weight_map, error code=",
|
||||
static_cast<int>(res));
|
||||
|
||||
size[1] = batch_size;
|
||||
box_size[1] = TILE_N;
|
||||
|
||||
res = cuTensorMapEncodeTiled(
|
||||
&activation_map, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
|
||||
rank, gA, size, stride, box_size, elem_stride,
|
||||
CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
|
||||
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B,
|
||||
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE,
|
||||
CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
|
||||
TORCH_CHECK(res == CUDA_SUCCESS,
|
||||
"cuTensorMapEncodeTiled failed for activation_map, error code=",
|
||||
static_cast<int>(res));
|
||||
|
||||
int smem_size = STAGES * STAGE_UNROLL *
|
||||
(TILE_M * TILE_K * sizeof(__nv_bfloat16) +
|
||||
TILE_N * TILE_K * sizeof(__nv_bfloat16));
|
||||
|
||||
gpuErrChk(cudaFuncSetAttribute(
|
||||
gpt_oss_router_gemm_kernel<WARP_TILE_M, TILE_M, TILE_N, TILE_K, STAGES,
|
||||
STAGE_UNROLL, PROFILE>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
||||
|
||||
int tiles_m = (output_features + TILE_M - 1) / TILE_M;
|
||||
int tiles_n = (batch_size + TILE_N - 1) / TILE_N;
|
||||
|
||||
dim3 grid(tiles_m, tiles_n);
|
||||
dim3 block(384);
|
||||
|
||||
cudaLaunchConfig_t config;
|
||||
cudaLaunchAttribute attrs[1];
|
||||
config.gridDim = grid;
|
||||
config.blockDim = block;
|
||||
config.dynamicSmemBytes = smem_size;
|
||||
config.stream = stream;
|
||||
config.attrs = attrs;
|
||||
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
|
||||
attrs[0].val.programmaticStreamSerializationAllowed = 1;
|
||||
config.numAttrs = 1;
|
||||
|
||||
cudaLaunchKernelEx(
|
||||
&config,
|
||||
&gpt_oss_router_gemm_kernel<WARP_TILE_M, TILE_M, TILE_N, TILE_K, STAGES,
|
||||
STAGE_UNROLL, PROFILE>,
|
||||
gC, gA, gB, bias, output_features, batch_size, input_features, weight_map,
|
||||
activation_map, nullptr);
|
||||
}
|
||||
|
||||
void gpt_oss_router_gemm_cuda_forward(torch::Tensor& output,
|
||||
torch::Tensor input, torch::Tensor weight,
|
||||
torch::Tensor bias) {
|
||||
auto const batch_size = input.size(0);
|
||||
auto const input_dim = input.size(1);
|
||||
auto const output_dim = weight.size(0);
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
if (input.scalar_type() == at::ScalarType::BFloat16) {
|
||||
launch_gpt_oss_router_gemm((__nv_bfloat16*)input.data_ptr(),
|
||||
(__nv_bfloat16*)weight.data_ptr(),
|
||||
(__nv_bfloat16*)output.mutable_data_ptr(),
|
||||
(__nv_bfloat16*)bias.data_ptr(), batch_size,
|
||||
output_dim, input_dim, stream);
|
||||
} else {
|
||||
throw std::invalid_argument("Unsupported dtype, only supports bfloat16");
|
||||
}
|
||||
}
|
||||
|
||||
void gpt_oss_router_gemm(torch::Tensor& output, torch::Tensor input,
|
||||
torch::Tensor weight, torch::Tensor bias) {
|
||||
TORCH_CHECK(input.dim() == 2, "input must be 2D");
|
||||
TORCH_CHECK(weight.dim() == 2, "weight must be 2D");
|
||||
TORCH_CHECK(bias.dim() == 1, "bias must be 1D");
|
||||
TORCH_CHECK(input.sizes()[1] == weight.sizes()[1],
|
||||
"input.size(1) must match weight.size(1)");
|
||||
TORCH_CHECK(weight.sizes()[0] == bias.sizes()[0],
|
||||
"weight.size(0) must match bias.size(0)");
|
||||
TORCH_CHECK(input.scalar_type() == at::ScalarType::BFloat16,
|
||||
"input tensor must be bfloat16");
|
||||
TORCH_CHECK(weight.scalar_type() == at::ScalarType::BFloat16,
|
||||
"weight tensor must be bfloat16");
|
||||
TORCH_CHECK(bias.scalar_type() == at::ScalarType::BFloat16,
|
||||
"bias tensor must be bfloat16");
|
||||
gpt_oss_router_gemm_cuda_forward(output, input, weight, bias);
|
||||
}
|
||||
@@ -1,447 +0,0 @@
|
||||
/*
|
||||
* Adapted from
|
||||
* https://github.com/NVIDIA/TensorRT-LLM/blob/v1.3.0rc7/cpp/tensorrt_llm/kernels/tinygemm2/tinygemm2_kernel.cuh
|
||||
* Copyright (c) 2025, The vLLM team.
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
|
||||
* All rights reserved. SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "cuda_bf16.h"
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
#include <vector>
|
||||
|
||||
#include "cuda_pipeline.h"
|
||||
#include <cuda.h>
|
||||
#include <cuda/barrier>
|
||||
#include <cuda/std/utility>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
using barrier = cuda::barrier<cuda::thread_scope_block>;
|
||||
namespace cde = cuda::device::experimental;
|
||||
namespace ptx = cuda::ptx;
|
||||
|
||||
#define gpuErrChk(ans) \
|
||||
{ \
|
||||
gpuAssert((ans), __FILE__, __LINE__); \
|
||||
}
|
||||
|
||||
inline void gpuAssert(cudaError_t code, char const* file, int line,
|
||||
bool abort = true) {
|
||||
if (code != cudaSuccess) {
|
||||
fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file,
|
||||
line);
|
||||
if (abort) {
|
||||
throw std::runtime_error(cudaGetErrorString(code));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
|
||||
__device__ uint64_t gclock64() {
|
||||
unsigned long long int rv;
|
||||
asm volatile("mov.u64 %0, %%globaltimer;" : "=l"(rv));
|
||||
return rv;
|
||||
}
|
||||
|
||||
__device__ void ldmatrix(__nv_bfloat16 rv[2], uint32_t smem_ptr) {
|
||||
int dst;
|
||||
asm volatile("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n"
|
||||
: "=r"(dst)
|
||||
: "r"(smem_ptr));
|
||||
int* rvi = reinterpret_cast<int*>(&rv[0]);
|
||||
rvi[0] = dst;
|
||||
}
|
||||
|
||||
__device__ void ldmatrix2(__nv_bfloat16 rv[4], uint32_t smem_ptr) {
|
||||
int x, y;
|
||||
asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n"
|
||||
: "=r"(x), "=r"(y)
|
||||
: "r"(smem_ptr));
|
||||
|
||||
int* rvi = reinterpret_cast<int*>(&rv[0]);
|
||||
rvi[0] = x;
|
||||
rvi[1] = y;
|
||||
}
|
||||
|
||||
__device__ void ldmatrix4(__nv_bfloat16 rv[8], uint32_t smem_ptr) {
|
||||
int x, y, z, w;
|
||||
asm volatile(
|
||||
"ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];"
|
||||
: "=r"(x), "=r"(y), "=r"(z), "=r"(w)
|
||||
: "r"(smem_ptr));
|
||||
int* rvi = reinterpret_cast<int*>(&rv[0]);
|
||||
rvi[0] = x;
|
||||
rvi[1] = y;
|
||||
rvi[2] = z;
|
||||
rvi[3] = w;
|
||||
}
|
||||
|
||||
__device__ void HMMA_1688(float d[4], __nv_bfloat16 a[4], __nv_bfloat16 b[2],
|
||||
float c[4]) {
|
||||
uint32_t const* A = reinterpret_cast<uint32_t const*>(&a[0]);
|
||||
uint32_t const* B = reinterpret_cast<uint32_t const*>(&b[0]);
|
||||
float const* C = reinterpret_cast<float const*>(&c[0]);
|
||||
float* D = reinterpret_cast<float*>(&d[0]);
|
||||
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
|
||||
: "r"(A[0]), "r"(A[1]), "r"(B[0]), "f"(C[0]), "f"(C[1]), "f"(C[2]),
|
||||
"f"(C[3]));
|
||||
}
|
||||
|
||||
__device__ void HMMA_16816(float d[4], __nv_bfloat16 a[8], __nv_bfloat16 b[4],
|
||||
float c[4]) {
|
||||
uint32_t const* A = reinterpret_cast<uint32_t const*>(&a[0]);
|
||||
uint32_t const* B = reinterpret_cast<uint32_t const*>(&b[0]);
|
||||
float const* C = reinterpret_cast<float const*>(&c[0]);
|
||||
float* D = reinterpret_cast<float*>(&d[0]);
|
||||
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
|
||||
: "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
|
||||
"f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
|
||||
}
|
||||
|
||||
__device__ void bar_wait(uint32_t bar_ptr, int phase) {
|
||||
asm volatile(
|
||||
"{\n"
|
||||
".reg .pred P1;\n"
|
||||
"LAB_WAIT:\n"
|
||||
"mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1;\n"
|
||||
"@P1 bra.uni DONE;\n"
|
||||
"bra.uni LAB_WAIT;\n"
|
||||
"DONE:\n"
|
||||
"}\n" ::"r"(bar_ptr),
|
||||
"r"(phase));
|
||||
}
|
||||
|
||||
__device__ bool bar_try_wait(uint32_t bar_ptr, int phase) {
|
||||
uint32_t success;
|
||||
#ifdef INTERNAL
|
||||
asm volatile(".pragma \"set knob DontInsertYield\";\n" : : : "memory");
|
||||
#endif
|
||||
asm volatile(
|
||||
"{\n\t"
|
||||
".reg .pred P1; \n\t"
|
||||
"mbarrier.try_wait.parity.shared::cta.b64 P1, [%1], %2; \n\t"
|
||||
"selp.b32 %0, 1, 0, P1; \n\t"
|
||||
"}"
|
||||
: "=r"(success)
|
||||
: "r"(bar_ptr), "r"(phase));
|
||||
return success;
|
||||
}
|
||||
|
||||
__device__ uint32_t elect_one_sync() {
|
||||
uint32_t pred = 0;
|
||||
uint32_t laneid = 0;
|
||||
asm volatile(
|
||||
"{\n"
|
||||
".reg .b32 %%rx;\n"
|
||||
".reg .pred %%px;\n"
|
||||
" elect.sync %%rx|%%px, %2;\n"
|
||||
"@%%px mov.s32 %1, 1;\n"
|
||||
" mov.s32 %0, %%rx;\n"
|
||||
"}\n"
|
||||
: "+r"(laneid), "+r"(pred)
|
||||
: "r"(0xFFFFFFFF));
|
||||
return pred;
|
||||
}
|
||||
#endif
|
||||
|
||||
struct Profile {
|
||||
uint64_t start;
|
||||
uint64_t weight_load_start;
|
||||
uint64_t act_load_start;
|
||||
uint64_t compute_start;
|
||||
uint64_t complete;
|
||||
};
|
||||
|
||||
template <int WARP_TILE_M, int TILE_M, int TILE_N, int TILE_K, int STAGES,
|
||||
int STAGE_UNROLL, bool PROFILE>
|
||||
__global__ __launch_bounds__(384, 1) void gpt_oss_router_gemm_kernel(
|
||||
__nv_bfloat16* output, __nv_bfloat16* weights, __nv_bfloat16* activations,
|
||||
__nv_bfloat16* bias, int M, int N, int K,
|
||||
const __grid_constant__ CUtensorMap weight_map,
|
||||
const __grid_constant__ CUtensorMap activation_map,
|
||||
Profile* profile = nullptr) {
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
|
||||
|
||||
if (PROFILE && threadIdx.x == 0 && blockIdx.y == 0)
|
||||
profile[blockIdx.x].start = gclock64();
|
||||
|
||||
extern __shared__ __align__(128) char smem[];
|
||||
|
||||
__nv_bfloat16* sh_weights = (__nv_bfloat16*)&smem[0];
|
||||
__nv_bfloat16* sh_activations =
|
||||
(__nv_bfloat16*)&smem[STAGES * STAGE_UNROLL * TILE_M * TILE_K *
|
||||
sizeof(__nv_bfloat16)];
|
||||
|
||||
#pragma nv_diag_suppress static_var_with_dynamic_init
|
||||
__shared__ barrier bar_wt_ready[STAGES];
|
||||
__shared__ barrier bar_act_ready[STAGES];
|
||||
__shared__ barrier bar_data_consumed[STAGES];
|
||||
|
||||
__shared__ float4 reduction_buffer[128];
|
||||
|
||||
__shared__ nv_bfloat16 sh_bias[TILE_M];
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
for (int i = 0; i < STAGES; i++) {
|
||||
init(&bar_wt_ready[i], 1);
|
||||
init(&bar_act_ready[i], 1);
|
||||
init(&bar_data_consumed[i], 32);
|
||||
}
|
||||
ptx::fence_proxy_async(ptx::space_shared);
|
||||
asm volatile("prefetch.tensormap [%0];"
|
||||
:
|
||||
: "l"(reinterpret_cast<uint64_t>(&weight_map))
|
||||
: "memory");
|
||||
asm volatile("prefetch.tensormap [%0];"
|
||||
:
|
||||
: "l"(reinterpret_cast<uint64_t>(&activation_map))
|
||||
: "memory");
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
int warp_id = threadIdx.x / 32;
|
||||
int lane_id = threadIdx.x % 32;
|
||||
|
||||
int phase = 0;
|
||||
|
||||
int mib = blockIdx.x * TILE_M;
|
||||
int ni = blockIdx.y * TILE_N;
|
||||
|
||||
float accum[4];
|
||||
for (int i = 0; i < 4; i++) accum[i] = 0.f;
|
||||
|
||||
int const K_LOOPS_DMA =
|
||||
(K + 4 * TILE_K * STAGE_UNROLL - 1) / (4 * (TILE_K * STAGE_UNROLL));
|
||||
int const K_LOOPS_COMPUTE = K_LOOPS_DMA;
|
||||
|
||||
// Data loading thread
|
||||
if (warp_id >= 4 && elect_one_sync()) {
|
||||
int stage = warp_id % 4;
|
||||
|
||||
bool weight_warp = warp_id < 8;
|
||||
if (!weight_warp) {
|
||||
cudaGridDependencySynchronize();
|
||||
cudaTriggerProgrammaticLaunchCompletion();
|
||||
}
|
||||
|
||||
for (int ki = 0; ki < K_LOOPS_DMA; ki++) {
|
||||
int k = (ki * 4 + (warp_id % 4)) * TILE_K * STAGE_UNROLL;
|
||||
|
||||
uint64_t desc_ptr_wt = reinterpret_cast<uint64_t>(&weight_map);
|
||||
uint64_t desc_ptr_act = reinterpret_cast<uint64_t>(&activation_map);
|
||||
|
||||
uint32_t bar_ptr_wt = __cvta_generic_to_shared(&bar_wt_ready[stage]);
|
||||
uint32_t bar_ptr_act = __cvta_generic_to_shared(&bar_act_ready[stage]);
|
||||
int bytes_wt = TILE_M * TILE_K * sizeof(__nv_bfloat16);
|
||||
int bytes_act = TILE_N * TILE_K * sizeof(__nv_bfloat16);
|
||||
|
||||
bar_wait(__cvta_generic_to_shared(&bar_data_consumed[stage]), phase ^ 1);
|
||||
|
||||
if (weight_warp)
|
||||
asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;"
|
||||
:
|
||||
: "r"(bar_ptr_wt), "r"(STAGE_UNROLL * bytes_wt));
|
||||
if (!weight_warp)
|
||||
asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;"
|
||||
:
|
||||
: "r"(bar_ptr_act), "r"(STAGE_UNROLL * bytes_act));
|
||||
|
||||
if (PROFILE && blockIdx.y == 0 && ki == 0 && weight_warp)
|
||||
profile[blockIdx.x].weight_load_start = gclock64();
|
||||
if (PROFILE && blockIdx.y == 0 && ki == 0 && !weight_warp)
|
||||
profile[blockIdx.x].act_load_start = gclock64();
|
||||
|
||||
for (int i = 0; i < STAGE_UNROLL; i++) {
|
||||
uint32_t smem_ptr_wt = __cvta_generic_to_shared(
|
||||
&sh_weights[(stage * STAGE_UNROLL + i) * TILE_M * TILE_K]);
|
||||
uint32_t crd0 = k + i * TILE_K;
|
||||
uint32_t crd1 = mib;
|
||||
if (weight_warp)
|
||||
asm volatile(
|
||||
"cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::complete_"
|
||||
"tx::bytes [%0], [%1, {%3,%4}], "
|
||||
"[%2];"
|
||||
:
|
||||
: "r"(smem_ptr_wt), "l"(desc_ptr_wt), "r"(bar_ptr_wt), "r"(crd0),
|
||||
"r"(crd1)
|
||||
: "memory");
|
||||
|
||||
uint32_t smem_ptr_act = __cvta_generic_to_shared(
|
||||
&sh_activations[(stage * STAGE_UNROLL + i) * TILE_N * TILE_K]);
|
||||
crd0 = k + i * TILE_K;
|
||||
crd1 = ni;
|
||||
if (!weight_warp)
|
||||
asm volatile(
|
||||
"cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::complete_"
|
||||
"tx::bytes [%0], [%1, {%3,%4}], "
|
||||
"[%2];"
|
||||
:
|
||||
: "r"(smem_ptr_act), "l"(desc_ptr_act), "r"(bar_ptr_act),
|
||||
"r"(crd0), "r"(crd1)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
stage += 4;
|
||||
if (stage >= STAGES) {
|
||||
stage = warp_id % 4;
|
||||
phase ^= 1;
|
||||
}
|
||||
}
|
||||
// Wait for pending loads to be consumed before exiting, to avoid race
|
||||
for (int i = 0; i < (STAGES / 4) - 1; i++) {
|
||||
bar_wait(__cvta_generic_to_shared(&bar_data_consumed[stage]), phase ^ 1);
|
||||
stage += 4;
|
||||
if (stage >= STAGES) {
|
||||
stage = warp_id % 4;
|
||||
phase ^= 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Compute threads
|
||||
else if (warp_id < 4) {
|
||||
// Sneak the bias load into the compute warps since they're just waiting for
|
||||
// stuff anyway
|
||||
if (threadIdx.x < TILE_M) sh_bias[threadIdx.x] = bias[mib + threadIdx.x];
|
||||
|
||||
int stage = warp_id;
|
||||
|
||||
int phase = 0;
|
||||
int lane_id_div8 = lane_id / 8;
|
||||
int lane_id_mod8 = lane_id % 8;
|
||||
|
||||
int lane_row_offset_wt = (lane_id_div8 % 2) ? 8 : 0;
|
||||
int lane_col_offset_wt = (lane_id_div8 / 2) ? 1 : 0;
|
||||
|
||||
int row_wt = lane_id_mod8 + lane_row_offset_wt;
|
||||
int row_act = lane_id_mod8;
|
||||
|
||||
int row_offset_wt = (reinterpret_cast<uintptr_t>(sh_weights) / 128) % 8;
|
||||
int row_offset_act = row_offset_wt;
|
||||
|
||||
uint32_t bar_ptr_wt = __cvta_generic_to_shared(&bar_wt_ready[stage]);
|
||||
uint32_t bar_ptr_act = __cvta_generic_to_shared(&bar_act_ready[stage]);
|
||||
|
||||
bool weight_ready = bar_try_wait(bar_ptr_wt, phase);
|
||||
bool act_ready = bar_try_wait(bar_ptr_act, phase);
|
||||
|
||||
#pragma unroll 2
|
||||
for (int ki = 0; ki < K_LOOPS_COMPUTE; ki++) {
|
||||
int next_stage = stage + 4;
|
||||
int next_phase = phase;
|
||||
if (next_stage >= STAGES) {
|
||||
next_stage = warp_id;
|
||||
next_phase ^= 1;
|
||||
}
|
||||
|
||||
while (!weight_ready || !act_ready) {
|
||||
weight_ready = bar_try_wait(bar_ptr_wt, phase);
|
||||
act_ready = bar_try_wait(bar_ptr_act, phase);
|
||||
}
|
||||
|
||||
if (PROFILE && blockIdx.y == 0 && threadIdx.x == 0 && ki == 0)
|
||||
profile[blockIdx.x].compute_start = gclock64();
|
||||
|
||||
if (ki + 1 < K_LOOPS_COMPUTE) {
|
||||
weight_ready = bar_try_wait(
|
||||
__cvta_generic_to_shared(&bar_wt_ready[next_stage]), next_phase);
|
||||
act_ready = bar_try_wait(
|
||||
__cvta_generic_to_shared(&bar_act_ready[next_stage]), next_phase);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int su = 0; su < STAGE_UNROLL; su++) {
|
||||
__nv_bfloat16* ptr_weights =
|
||||
&sh_weights[(stage * STAGE_UNROLL + su) * TILE_M * TILE_K];
|
||||
__nv_bfloat16* ptr_act =
|
||||
&sh_activations[(stage * STAGE_UNROLL + su) * TILE_N * TILE_K];
|
||||
|
||||
#pragma unroll
|
||||
for (int kii = 0; kii < TILE_K / 16; kii++) {
|
||||
__nv_bfloat16 a[8];
|
||||
__nv_bfloat16 b[4];
|
||||
|
||||
int col = 2 * kii + lane_col_offset_wt;
|
||||
int col_sw = ((row_wt + row_offset_wt) % 8) ^ col;
|
||||
|
||||
ldmatrix4(a, __cvta_generic_to_shared(
|
||||
&ptr_weights[row_wt * TILE_K + col_sw * 8]));
|
||||
|
||||
col = 2 * kii + lane_id_div8;
|
||||
col_sw = ((row_act + row_offset_act) % 8) ^ col;
|
||||
|
||||
ldmatrix2(b, __cvta_generic_to_shared(
|
||||
&ptr_act[row_act * TILE_K + 8 * col_sw]));
|
||||
|
||||
HMMA_16816(accum, a, b, accum);
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t bar_c = __cvta_generic_to_shared(&bar_data_consumed[stage]);
|
||||
asm volatile("mbarrier.arrive.shared::cta.b64 _, [%0];" : : "r"(bar_c));
|
||||
|
||||
stage = next_stage;
|
||||
phase = next_phase;
|
||||
}
|
||||
|
||||
float4 accum4;
|
||||
accum4.x = accum[0];
|
||||
accum4.y = accum[1];
|
||||
accum4.z = accum[2];
|
||||
accum4.w = accum[3];
|
||||
reduction_buffer[threadIdx.x] = accum4;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (warp_id == 0) {
|
||||
int mi = mib + warp_id * WARP_TILE_M;
|
||||
int tm = mi + lane_id / 4;
|
||||
int tn = ni + 2 * (lane_id % 4);
|
||||
|
||||
float4 accum1 = reduction_buffer[32 + threadIdx.x];
|
||||
float4 accum2 = reduction_buffer[64 + threadIdx.x];
|
||||
float4 accum3 = reduction_buffer[96 + threadIdx.x];
|
||||
|
||||
accum[0] = accum[0] + accum1.x + accum2.x + accum3.x;
|
||||
accum[1] = accum[1] + accum1.y + accum2.y + accum3.y;
|
||||
accum[2] = accum[2] + accum1.z + accum2.z + accum3.z;
|
||||
accum[3] = accum[3] + accum1.w + accum2.w + accum3.w;
|
||||
|
||||
float bias_lo = __bfloat162float(sh_bias[tm - mib]);
|
||||
float bias_hi = __bfloat162float(sh_bias[tm + 8 - mib]);
|
||||
|
||||
if (tn < N && tm < M)
|
||||
output[tn * M + tm] = __float2bfloat16(accum[0] + bias_lo);
|
||||
if (tn + 1 < N && tm < M)
|
||||
output[(tn + 1) * M + tm] = __float2bfloat16(accum[1] + bias_lo);
|
||||
if (tn < N && tm + 8 < M)
|
||||
output[tn * M + tm + 8] = __float2bfloat16(accum[2] + bias_hi);
|
||||
if (tn + 1 < N && tm + 8 < M)
|
||||
output[(tn + 1) * M + tm + 8] = __float2bfloat16(accum[3] + bias_hi);
|
||||
|
||||
if (PROFILE && blockIdx.y == 0 && threadIdx.x == 0)
|
||||
profile[blockIdx.x].complete = gclock64();
|
||||
}
|
||||
}
|
||||
#endif // end if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
|
||||
}
|
||||
@@ -108,6 +108,15 @@ QUANT_CONFIGS = [
|
||||
"thread_m_blocks": THREAD_M_BLOCKS,
|
||||
"group_blocks": [2],
|
||||
},
|
||||
# MXFP8
|
||||
{
|
||||
"a_type": ["kBFloat16"],
|
||||
"b_type": "kFE4M3fn",
|
||||
"s_type": "kFE8M0fnu",
|
||||
"thread_configs": THREAD_CONFIGS,
|
||||
"thread_m_blocks": THREAD_M_BLOCKS,
|
||||
"group_blocks": [2],
|
||||
},
|
||||
# AWQ-INT4 with INT8 activation
|
||||
{
|
||||
"a_type": ["kS8"],
|
||||
|
||||
@@ -343,6 +343,8 @@ __global__ void Marlin(
|
||||
if constexpr (b_type == vllm::kFE2M1f) {
|
||||
static_assert(s_type == vllm::kFE4M3fn && group_blocks == 1 ||
|
||||
s_type == vllm::kFE8M0fnu && group_blocks == 2);
|
||||
} else if constexpr (b_type == vllm::kFE4M3fn && s_type == vllm::kFE8M0fnu) {
|
||||
static_assert(group_blocks == 2);
|
||||
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
|
||||
static_assert(s_type == vllm::kBFloat16);
|
||||
} else if constexpr (std::is_same<scalar_t, half>::value) {
|
||||
@@ -357,9 +359,10 @@ __global__ void Marlin(
|
||||
constexpr bool is_int_type = b_type == vllm::kU4 || b_type == vllm::kU8 ||
|
||||
b_type == vllm::kS4 || b_type == vllm::kS8 ||
|
||||
b_type == vllm::kU4B8 || b_type == vllm::kU8B128;
|
||||
constexpr bool is_8bit_scale = s_type.size_bits() == 8;
|
||||
// see comments of dequant.h for more details
|
||||
constexpr bool dequant_skip_flop =
|
||||
is_a_8bit || b_type == vllm::kFE4M3fn ||
|
||||
is_a_8bit || (b_type == vllm::kFE4M3fn && !(s_type == vllm::kFE8M0fnu)) ||
|
||||
b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn ||
|
||||
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
|
||||
has_zp && !is_zp_float && !(b_type == vllm::kU8);
|
||||
@@ -373,7 +376,7 @@ __global__ void Marlin(
|
||||
const int group_size =
|
||||
(!has_act_order && group_blocks == -1) ? prob_k : prob_k / num_groups;
|
||||
const int scales_expert_stride =
|
||||
prob_n * prob_k / group_size / (b_type == vllm::kFE2M1f ? 16 : 8);
|
||||
prob_n * prob_k / group_size / (is_8bit_scale ? 16 : 8);
|
||||
const int zp_expert_stride =
|
||||
is_zp_float ? prob_n * prob_k / group_size / 8
|
||||
: prob_n * prob_k / group_size / (pack_factor * 4);
|
||||
@@ -692,9 +695,8 @@ __global__ void Marlin(
|
||||
constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
|
||||
|
||||
// Scale sizes/strides without act_order
|
||||
int s_gl_stride = prob_n / (b_type == vllm::kFE2M1f ? 16 : 8);
|
||||
constexpr int s_sh_stride =
|
||||
16 * thread_n_blocks / (b_type == vllm::kFE2M1f ? 16 : 8);
|
||||
int s_gl_stride = prob_n / (is_8bit_scale ? 16 : 8);
|
||||
constexpr int s_sh_stride = 16 * thread_n_blocks / (is_8bit_scale ? 16 : 8);
|
||||
constexpr int s_tb_groups =
|
||||
!has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
|
||||
? thread_k_blocks / group_blocks
|
||||
@@ -1131,7 +1133,7 @@ __global__ void Marlin(
|
||||
|
||||
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
|
||||
|
||||
if constexpr (b_type_id != vllm::kFE2M1f.id()) {
|
||||
if constexpr (!is_8bit_scale) {
|
||||
reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
|
||||
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
|
||||
} else {
|
||||
@@ -1140,7 +1142,7 @@ __global__ void Marlin(
|
||||
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
|
||||
}
|
||||
} else if (group_blocks >= b_sh_wr_iters) {
|
||||
if constexpr (b_type_id != vllm::kFE2M1f.id()) {
|
||||
if constexpr (!is_8bit_scale) {
|
||||
reinterpret_cast<int4*>(&frag_s[1])[0] =
|
||||
reinterpret_cast<int4*>(&frag_s[0])[0];
|
||||
} else {
|
||||
@@ -1341,7 +1343,7 @@ __global__ void Marlin(
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (b_type == vllm::kFE2M1f) {
|
||||
if constexpr (s_type == vllm::kFE4M3fn || s_type == vllm::kFE8M0fnu) {
|
||||
int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];
|
||||
int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];
|
||||
|
||||
|
||||
@@ -599,6 +599,9 @@ torch::Tensor moe_wna16_marlin_gemm(
|
||||
"When b_type = float4_e2m1f, b_scale scalar type must be",
|
||||
"float8_e4m3fn (for NVFP4) or float8_e8m0fnu (for MXFP4).");
|
||||
}
|
||||
} else if (b_type_id == vllm::kFE4M3fn.id() &&
|
||||
b_scales.scalar_type() == at::ScalarType::Float8_e8m0fnu) {
|
||||
s_type_id = vllm::kFE8M0fnu.id();
|
||||
}
|
||||
|
||||
vllm::ScalarType a_type = vllm::ScalarType::from_id(a_type_id);
|
||||
|
||||
@@ -70,8 +70,4 @@ torch::Tensor router_gemm_bf16_fp32(torch::Tensor const& input,
|
||||
// Supports num_tokens in [1, 16], num_experts in {256, 384}, hidden_dim = 7168
|
||||
void dsv3_router_gemm(torch::Tensor& output, const torch::Tensor& mat_a,
|
||||
const torch::Tensor& mat_b);
|
||||
|
||||
// gpt-oss optimized router GEMM kernel for SM90+
|
||||
void gpt_oss_router_gemm(torch::Tensor& output, torch::Tensor input,
|
||||
torch::Tensor weight, torch::Tensor bias);
|
||||
#endif
|
||||
|
||||
@@ -132,12 +132,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
|
||||
// DeepSeek V3 optimized router GEMM for SM90+
|
||||
m.def("dsv3_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
|
||||
// conditionally compiled so impl registration is in source file
|
||||
|
||||
// gpt-oss optimized router GEMM kernel for SM90+
|
||||
m.def(
|
||||
"gpt_oss_router_gemm(Tensor! output, Tensor input, Tensor weights, "
|
||||
"Tensor bias) -> ()");
|
||||
m.impl("gpt_oss_router_gemm", torch::kCUDA, &gpt_oss_router_gemm);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
68
csrc/ops.h
68
csrc/ops.h
@@ -53,12 +53,12 @@ void paged_attention_v2(
|
||||
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||
const int64_t blocksparse_head_sliding_step);
|
||||
|
||||
void merge_attn_states(torch::Tensor& output,
|
||||
std::optional<torch::Tensor> output_lse,
|
||||
const torch::Tensor& prefix_output,
|
||||
const torch::Tensor& prefix_lse,
|
||||
const torch::Tensor& suffix_output,
|
||||
const torch::Tensor& suffix_lse);
|
||||
void merge_attn_states(
|
||||
torch::Tensor& output, std::optional<torch::Tensor> output_lse,
|
||||
const torch::Tensor& prefix_output, const torch::Tensor& prefix_lse,
|
||||
const torch::Tensor& suffix_output, const torch::Tensor& suffix_lse,
|
||||
const std::optional<int64_t> prefill_tokens_with_context,
|
||||
const std::optional<torch::Tensor>& output_scale = std::nullopt);
|
||||
#ifndef USE_ROCM
|
||||
void convert_vertical_slash_indexes(
|
||||
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
|
||||
@@ -114,9 +114,9 @@ void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
|
||||
int64_t numRows, int64_t stride0, int64_t stride1,
|
||||
int64_t topK);
|
||||
|
||||
void large_context_topk(const torch::Tensor& score, torch::Tensor& indices,
|
||||
const torch::Tensor& lengths,
|
||||
std::optional<torch::Tensor> row_starts_opt);
|
||||
void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths,
|
||||
torch::Tensor& output, torch::Tensor& workspace, int64_t k,
|
||||
int64_t max_seq_len);
|
||||
|
||||
void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
|
||||
torch::Tensor& weight, torch::Tensor& scale,
|
||||
@@ -143,6 +143,12 @@ void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||
std::optional<torch::Tensor> residual,
|
||||
int64_t group_size, bool is_scale_transposed);
|
||||
|
||||
void silu_and_mul_per_block_quant(torch::Tensor& out,
|
||||
torch::Tensor const& input,
|
||||
torch::Tensor& scales, int64_t group_size,
|
||||
std::optional<torch::Tensor> scale_ub,
|
||||
bool is_scale_transposed);
|
||||
|
||||
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
|
||||
std::optional<torch::Tensor> key, int64_t head_size,
|
||||
torch::Tensor& cos_sin_cache, bool is_neox);
|
||||
@@ -152,12 +158,6 @@ void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
|
||||
void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input,
|
||||
torch::Tensor& scale);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
void silu_and_mul_nvfp4_quant(torch::Tensor& out,
|
||||
torch::Tensor& output_block_scale,
|
||||
torch::Tensor& input,
|
||||
torch::Tensor& input_global_scale);
|
||||
#endif
|
||||
void persistent_masked_m_silu_mul_quant(
|
||||
const at::Tensor& input, // (E, T, 2*H)
|
||||
const at::Tensor& counts, // (E)
|
||||
@@ -225,44 +225,6 @@ torch::Tensor ggml_moe_a8_vec(torch::Tensor X, torch::Tensor W,
|
||||
|
||||
int64_t ggml_moe_get_block_size(int64_t type);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
|
||||
bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability);
|
||||
|
||||
void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A,
|
||||
torch::Tensor const& B, torch::Tensor const& A_sf,
|
||||
torch::Tensor const& B_sf,
|
||||
torch::Tensor const& alpha);
|
||||
|
||||
void cutlass_fp4_group_mm(
|
||||
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
|
||||
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
|
||||
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
|
||||
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets);
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> scaled_fp4_quant_func(
|
||||
torch::Tensor const& input, torch::Tensor const& input_scale,
|
||||
bool is_sf_swizzled_layout);
|
||||
|
||||
void scaled_fp4_quant_out(torch::Tensor const& input,
|
||||
torch::Tensor const& input_scale,
|
||||
bool is_sf_swizzled_layout, torch::Tensor& output,
|
||||
torch::Tensor& output_scale);
|
||||
|
||||
void scaled_fp4_experts_quant(
|
||||
torch::Tensor& output, torch::Tensor& output_scale,
|
||||
torch::Tensor const& input, torch::Tensor const& input_global_scale,
|
||||
torch::Tensor const& input_offset_by_experts,
|
||||
torch::Tensor const& output_scale_offset_by_experts);
|
||||
|
||||
void silu_and_mul_scaled_fp4_experts_quant(
|
||||
torch::Tensor& output, torch::Tensor& output_scale,
|
||||
torch::Tensor const& input, torch::Tensor const& input_global_scale,
|
||||
torch::Tensor const& input_offset_by_experts,
|
||||
torch::Tensor const& output_scale_offset_by_experts);
|
||||
|
||||
#endif
|
||||
|
||||
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||
torch::Tensor const& scale,
|
||||
std::optional<torch::Tensor> const& azp);
|
||||
|
||||
1321
csrc/persistent_topk.cuh
Normal file
1321
csrc/persistent_topk.cuh
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,163 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
#include "nvfp4_utils.cuh"
|
||||
|
||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||
void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
|
||||
torch::Tensor const& input,
|
||||
torch::Tensor const& output_sf,
|
||||
torch::Tensor const& input_sf,
|
||||
bool is_sf_swizzled_layout);
|
||||
#endif
|
||||
|
||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||
void scaled_fp4_experts_quant_sm1xxa(
|
||||
torch::Tensor& output, torch::Tensor& output_scale,
|
||||
torch::Tensor const& input, torch::Tensor const& input_global_scale,
|
||||
torch::Tensor const& input_offset_by_experts,
|
||||
torch::Tensor const& output_scale_offset_by_experts);
|
||||
#endif
|
||||
|
||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||
void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output,
|
||||
torch::Tensor& output_sf,
|
||||
torch::Tensor& input,
|
||||
torch::Tensor& input_sf);
|
||||
#endif
|
||||
|
||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||
void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
|
||||
torch::Tensor& output, torch::Tensor& output_scale,
|
||||
torch::Tensor const& input, torch::Tensor const& input_global_scale,
|
||||
torch::Tensor const& input_offset_by_experts,
|
||||
torch::Tensor const& output_scale_offset_by_experts);
|
||||
#endif
|
||||
|
||||
static bool nvfp4_quant_sm_supported() {
|
||||
const int32_t sm = get_sm_version_num();
|
||||
#if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100
|
||||
if (sm >= 100 && sm < 120) return true;
|
||||
#endif
|
||||
#if defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120
|
||||
if (sm >= 120 && sm < 130) return true;
|
||||
#endif
|
||||
return false;
|
||||
}
|
||||
|
||||
void scaled_fp4_quant_out(torch::Tensor const& input,
|
||||
torch::Tensor const& input_sf,
|
||||
bool is_sf_swizzled_layout, torch::Tensor& output,
|
||||
torch::Tensor& output_sf) {
|
||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||
TORCH_CHECK(nvfp4_quant_sm_supported(),
|
||||
"No compiled nvfp4 quantization kernel for SM ",
|
||||
get_sm_version_num(),
|
||||
". Recompile with the appropriate CUDA arch.");
|
||||
return scaled_fp4_quant_sm1xxa(output, input, output_sf, input_sf,
|
||||
is_sf_swizzled_layout);
|
||||
#endif
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization kernel");
|
||||
}
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> scaled_fp4_quant_func(
|
||||
torch::Tensor const& input, torch::Tensor const& input_sf,
|
||||
bool is_sf_swizzled_layout) {
|
||||
int64_t n = input.size(-1);
|
||||
int64_t m = input.numel() / n;
|
||||
auto device = input.device();
|
||||
|
||||
// Two fp4 values packed into a uint8
|
||||
auto output = torch::empty(
|
||||
{m, n / 2}, torch::TensorOptions().device(device).dtype(torch::kUInt8));
|
||||
|
||||
torch::Tensor output_sf;
|
||||
if (is_sf_swizzled_layout) {
|
||||
auto [sf_m, sf_n] = vllm::computeSwizzledSFShape(m, n);
|
||||
output_sf = torch::empty(
|
||||
{sf_m, sf_n},
|
||||
torch::TensorOptions().device(device).dtype(torch::kInt32));
|
||||
} else {
|
||||
output_sf = torch::empty(
|
||||
{m, n / CVT_FP4_SF_VEC_SIZE},
|
||||
torch::TensorOptions().device(device).dtype(torch::kUInt8));
|
||||
}
|
||||
|
||||
scaled_fp4_quant_out(input, input_sf, is_sf_swizzled_layout, output,
|
||||
output_sf);
|
||||
return {output, output_sf};
|
||||
}
|
||||
|
||||
void scaled_fp4_experts_quant(
|
||||
torch::Tensor& output, torch::Tensor& output_scale,
|
||||
torch::Tensor const& input, torch::Tensor const& input_global_scale,
|
||||
torch::Tensor const& input_offset_by_experts,
|
||||
torch::Tensor const& output_scale_offset_by_experts) {
|
||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||
TORCH_CHECK(nvfp4_quant_sm_supported(),
|
||||
"No compiled nvfp4 experts quantization kernel for SM ",
|
||||
get_sm_version_num(),
|
||||
". Recompile with the appropriate CUDA arch.");
|
||||
return scaled_fp4_experts_quant_sm1xxa(
|
||||
output, output_scale, input, input_global_scale, input_offset_by_experts,
|
||||
output_scale_offset_by_experts);
|
||||
#endif
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false,
|
||||
"No compiled nvfp4 experts quantization kernel");
|
||||
}
|
||||
|
||||
void silu_and_mul_nvfp4_quant(torch::Tensor& output, torch::Tensor& output_sf,
|
||||
torch::Tensor& input, torch::Tensor& input_sf) {
|
||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||
TORCH_CHECK(nvfp4_quant_sm_supported(),
|
||||
"No compiled silu_and_mul nvfp4 quantization kernel for SM ",
|
||||
get_sm_version_num(),
|
||||
". Recompile with the appropriate CUDA arch.");
|
||||
return silu_and_mul_nvfp4_quant_sm1xxa(output, output_sf, input, input_sf);
|
||||
#endif
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false, "No compiled silu_and_mul nvfp4 quantization kernel");
|
||||
}
|
||||
|
||||
void silu_and_mul_scaled_fp4_experts_quant(
|
||||
torch::Tensor& output, torch::Tensor& output_scale,
|
||||
torch::Tensor const& input, torch::Tensor const& input_global_scale,
|
||||
torch::Tensor const& input_offset_by_experts,
|
||||
torch::Tensor const& output_scale_offset_by_experts) {
|
||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||
TORCH_CHECK(nvfp4_quant_sm_supported(),
|
||||
"No compiled silu_and_mul nvfp4 experts quantization kernel "
|
||||
"for SM ",
|
||||
get_sm_version_num(),
|
||||
". Recompile with the appropriate CUDA arch.");
|
||||
return silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
|
||||
output, output_scale, input, input_global_scale, input_offset_by_experts,
|
||||
output_scale_offset_by_experts);
|
||||
#endif
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false, "No compiled silu_and_mul nvfp4 experts quantization kernel");
|
||||
}
|
||||
169
csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu
Normal file
169
csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu
Normal file
@@ -0,0 +1,169 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "../../dispatch_utils.h"
|
||||
#include "quant_conversions.cuh"
|
||||
#include "../w8a8/fp8/common.cuh"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// Logic: one thread block per (token, group) pair
|
||||
|
||||
template <typename scalar_t, typename scalar_out_t, bool is_scale_transposed,
|
||||
int32_t group_size>
|
||||
__global__ void silu_and_mul_per_block_quant_kernel(
|
||||
scalar_out_t* __restrict__ out, // Output: [num_tokens, hidden_size] in
|
||||
// FP8/INT8
|
||||
float* __restrict__ scales, // Output: [num_tokens, hidden_size /
|
||||
// group_size] or [hidden_size / group_size,
|
||||
// num_tokens]
|
||||
scalar_t const* __restrict__ input, // Input: [num_tokens, hidden_size * 2]
|
||||
float const* scale_ub, // Optional scale upper bound
|
||||
int32_t const hidden_size // Output hidden size (input is 2x this)
|
||||
) {
|
||||
static_assert((group_size & (group_size - 1)) == 0,
|
||||
"group_size must be a power of 2 for correct reduction");
|
||||
|
||||
// Grid: (num_tokens, num_groups)
|
||||
int const token_idx = blockIdx.x;
|
||||
int const group_idx = blockIdx.y;
|
||||
int const tid = threadIdx.x; // tid in [0, group_size)
|
||||
int const num_tokens = gridDim.x;
|
||||
|
||||
// Input layout: [gate || up] concatenated along last dimension
|
||||
int const input_stride = hidden_size * 2;
|
||||
int const group_start = group_idx * group_size;
|
||||
|
||||
// Pointers to this token's data
|
||||
scalar_t const* token_input_gate =
|
||||
input + token_idx * input_stride + group_start;
|
||||
scalar_t const* token_input_up = token_input_gate + hidden_size;
|
||||
scalar_out_t* token_output = out + token_idx * hidden_size + group_start;
|
||||
|
||||
// Scale pointer for this group
|
||||
int const num_groups = gridDim.y;
|
||||
float* group_scale_ptr = is_scale_transposed
|
||||
? scales + group_idx * num_tokens + token_idx
|
||||
: scales + token_idx * num_groups + group_idx;
|
||||
|
||||
// Shared memory for reduction (compile-time sized)
|
||||
__shared__ float shared_max[group_size];
|
||||
|
||||
// Step 1: Each thread loads one element, computes SiLU, stores in register
|
||||
float gate = static_cast<float>(token_input_gate[tid]);
|
||||
float up = static_cast<float>(token_input_up[tid]);
|
||||
|
||||
// Compute SiLU(gate) * up
|
||||
float sigmoid_gate = 1.0f / (1.0f + expf(-gate));
|
||||
float silu_gate = gate * sigmoid_gate;
|
||||
float result = silu_gate * up; // Keep in register
|
||||
|
||||
// Step 2: Reduce to find group max
|
||||
shared_max[tid] = fabsf(result);
|
||||
__syncthreads();
|
||||
|
||||
// Power-of-2 reduction (group_size guaranteed to be power of 2)
|
||||
#pragma unroll
|
||||
for (int stride = group_size / 2; stride > 0; stride >>= 1) {
|
||||
if (tid < stride) {
|
||||
shared_max[tid] = fmaxf(shared_max[tid], shared_max[tid + stride]);
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Step 3: Compute scale (thread 0), broadcast via shared memory
|
||||
if (tid == 0) {
|
||||
float group_max = shared_max[0];
|
||||
|
||||
float const quant_range = quant_type_max_v<scalar_out_t>;
|
||||
float group_scale = group_max / quant_range;
|
||||
|
||||
// Apply scale upper bound if provided
|
||||
if (scale_ub != nullptr) {
|
||||
group_scale = fminf(group_scale, *scale_ub);
|
||||
}
|
||||
|
||||
// Use minimum safe scaling factor
|
||||
group_scale = fmaxf(group_scale, min_scaling_factor<scalar_out_t>::val());
|
||||
|
||||
// Store scale to global memory
|
||||
*group_scale_ptr = group_scale;
|
||||
|
||||
// Reuse shared_max[0] to broadcast scale
|
||||
shared_max[0] = group_scale;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
float group_scale = shared_max[0];
|
||||
|
||||
// Step 4: Quantize and write output
|
||||
token_output[tid] =
|
||||
vllm::ScaledQuant<scalar_out_t, false>::quant_fn(result, group_scale);
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
void silu_and_mul_per_block_quant(torch::Tensor& out,
|
||||
torch::Tensor const& input,
|
||||
torch::Tensor& scales, int64_t group_size,
|
||||
std::optional<torch::Tensor> scale_ub,
|
||||
bool is_scale_transposed) {
|
||||
static c10::ScalarType kFp8Type = is_fp8_ocp()
|
||||
? c10::ScalarType::Float8_e4m3fn
|
||||
: c10::ScalarType::Float8_e4m3fnuz;
|
||||
|
||||
TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8);
|
||||
TORCH_CHECK(out.is_contiguous() && input.is_contiguous());
|
||||
TORCH_CHECK(
|
||||
input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16,
|
||||
"Input must be FP16 or BF16");
|
||||
TORCH_CHECK(scales.dtype() == torch::kFloat32, "Scales must be FP32");
|
||||
TORCH_CHECK(group_size == 128 || group_size == 64,
|
||||
"Unsupported group size: ", group_size);
|
||||
|
||||
if (scale_ub.has_value()) {
|
||||
TORCH_CHECK(out.dtype() == kFp8Type);
|
||||
}
|
||||
|
||||
int32_t hidden_size = out.size(-1);
|
||||
auto num_tokens = input.size(0);
|
||||
int32_t num_groups = hidden_size / group_size;
|
||||
|
||||
TORCH_CHECK(input.size(-1) == hidden_size * 2,
|
||||
"input last dim must be 2x output hidden_size");
|
||||
TORCH_CHECK(hidden_size % group_size == 0,
|
||||
"hidden_size must be divisible by group_size");
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
dim3 grid(num_tokens, num_groups);
|
||||
dim3 block(group_size);
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "silu_and_mul_per_block_quant", [&] {
|
||||
using scalar_in_t = scalar_t;
|
||||
|
||||
VLLM_DISPATCH_QUANT_TYPES(
|
||||
out.scalar_type(), "silu_and_mul_per_block_quant", [&] {
|
||||
using scalar_out_t = scalar_t;
|
||||
|
||||
VLLM_DISPATCH_GROUP_SIZE(group_size, gs, [&] {
|
||||
VLLM_DISPATCH_BOOL(is_scale_transposed, transpose_scale, [&] {
|
||||
vllm::silu_and_mul_per_block_quant_kernel<
|
||||
scalar_in_t, scalar_out_t, transpose_scale, gs>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<scalar_out_t>(),
|
||||
scales.data_ptr<float>(),
|
||||
input.data_ptr<scalar_in_t>(),
|
||||
scale_ub.has_value() ? scale_ub->data_ptr<float>()
|
||||
: nullptr,
|
||||
hidden_size);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -6,7 +6,7 @@
|
||||
|
||||
#include "libtorch_stable/quantization/vectorization.cuh"
|
||||
// TODO(luka/varun):refactor common.cuh to use this file instead
|
||||
#include "quantization/w8a8/fp8/common.cuh"
|
||||
#include "../w8a8/fp8/common.cuh"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
|
||||
@@ -108,6 +108,15 @@ QUANT_CONFIGS = [
|
||||
"thread_m_blocks": THREAD_M_BLOCKS,
|
||||
"group_blocks": [2],
|
||||
},
|
||||
# MXFP8
|
||||
{
|
||||
"a_type": ["kBFloat16"],
|
||||
"b_type": "kFE4M3fn",
|
||||
"s_type": "kFE8M0fnu",
|
||||
"thread_configs": THREAD_CONFIGS,
|
||||
"thread_m_blocks": THREAD_M_BLOCKS,
|
||||
"group_blocks": [2],
|
||||
},
|
||||
# AWQ-INT4 with INT8 activation
|
||||
{
|
||||
"a_type": ["kS8"],
|
||||
|
||||
@@ -591,6 +591,9 @@ torch::Tensor marlin_gemm(
|
||||
"When b_type = float4_e2m1f, b_scale scalar type must be",
|
||||
"float8_e4m3fn (for NVFP4) or float8_e8m0fnu (for MXFP4).");
|
||||
}
|
||||
} else if (b_type_id == vllm::kFE4M3fn.id() &&
|
||||
b_scales.scalar_type() == at::ScalarType::Float8_e8m0fnu) {
|
||||
s_type_id = vllm::kFE8M0fnu.id();
|
||||
}
|
||||
|
||||
vllm::ScalarType a_type = vllm::ScalarType::from_id(a_type_id);
|
||||
|
||||
@@ -327,6 +327,9 @@ __global__ void Marlin(
|
||||
if constexpr (b_type == vllm::kFE2M1f) {
|
||||
static_assert(s_type == vllm::kFE4M3fn && group_blocks == 1 ||
|
||||
s_type == vllm::kFE8M0fnu && group_blocks == 2);
|
||||
} else if constexpr (s_type == vllm::kFE8M0fnu) {
|
||||
// MXFP8: FP8 weights with e8m0 microscaling block scales
|
||||
static_assert(b_type == vllm::kFE4M3fn && group_blocks == 2);
|
||||
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
|
||||
static_assert(s_type == vllm::kBFloat16);
|
||||
} else if constexpr (std::is_same<scalar_t, half>::value) {
|
||||
@@ -334,6 +337,7 @@ __global__ void Marlin(
|
||||
}
|
||||
|
||||
constexpr bool is_a_8bit = a_type.size_bits() == 8;
|
||||
constexpr bool is_8bit_scale = s_type.size_bits() == 8;
|
||||
if constexpr (!is_a_8bit) {
|
||||
static_assert(std::is_same<scalar_t, c_scalar_t>::value);
|
||||
}
|
||||
@@ -343,7 +347,7 @@ __global__ void Marlin(
|
||||
b_type == vllm::kU4B8 || b_type == vllm::kU8B128;
|
||||
// see comments of dequant.h for more details
|
||||
constexpr bool dequant_skip_flop =
|
||||
is_a_8bit || b_type == vllm::kFE4M3fn ||
|
||||
is_a_8bit || (b_type == vllm::kFE4M3fn && !(s_type == vllm::kFE8M0fnu)) ||
|
||||
b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn ||
|
||||
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
|
||||
has_zp && !is_zp_float && !(b_type == vllm::kU8);
|
||||
@@ -555,9 +559,8 @@ __global__ void Marlin(
|
||||
constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
|
||||
|
||||
// Scale sizes/strides without act_order
|
||||
int s_gl_stride = prob_n / (b_type == vllm::kFE2M1f ? 16 : 8);
|
||||
constexpr int s_sh_stride =
|
||||
16 * thread_n_blocks / (b_type == vllm::kFE2M1f ? 16 : 8);
|
||||
int s_gl_stride = prob_n / (is_8bit_scale ? 16 : 8);
|
||||
constexpr int s_sh_stride = 16 * thread_n_blocks / (is_8bit_scale ? 16 : 8);
|
||||
constexpr int s_tb_groups =
|
||||
!has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
|
||||
? thread_k_blocks / group_blocks
|
||||
@@ -997,7 +1000,7 @@ __global__ void Marlin(
|
||||
|
||||
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
|
||||
|
||||
if constexpr (b_type_id != vllm::kFE2M1f.id()) {
|
||||
if constexpr (!is_8bit_scale) {
|
||||
reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
|
||||
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
|
||||
} else {
|
||||
@@ -1006,7 +1009,7 @@ __global__ void Marlin(
|
||||
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
|
||||
}
|
||||
} else if (group_blocks >= b_sh_wr_iters) {
|
||||
if constexpr (b_type_id != vllm::kFE2M1f.id()) {
|
||||
if constexpr (!is_8bit_scale) {
|
||||
reinterpret_cast<int4*>(&frag_s[1])[0] =
|
||||
reinterpret_cast<int4*>(&frag_s[0])[0];
|
||||
} else {
|
||||
@@ -1207,7 +1210,7 @@ __global__ void Marlin(
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (b_type == vllm::kFE2M1f) {
|
||||
if constexpr (s_type == vllm::kFE4M3fn || s_type == vllm::kFE8M0fnu) {
|
||||
int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];
|
||||
int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "libtorch_stable/quantization/vectorization.cuh"
|
||||
#include "quantization/utils.cuh"
|
||||
#include "../../utils.cuh"
|
||||
|
||||
#include <cmath>
|
||||
|
||||
|
||||
@@ -564,8 +564,9 @@ template <int kNumThreadsPerBlock, bool useRadixSort,
|
||||
bool multipleBlocksPerRow = false, bool mergeBlocks = false>
|
||||
static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowDecode(
|
||||
const float* logits, const int* seqLens, int* outIndices, int stride0,
|
||||
int stride1, const int topK, int next_n, float* outLogits = nullptr,
|
||||
const int numBlocksToMerge = 0, const int* indices = nullptr) {
|
||||
int stride1, const int topK, int next_n, int seqLensIs2D = 0,
|
||||
float* outLogits = nullptr, const int numBlocksToMerge = 0,
|
||||
const int* indices = nullptr) {
|
||||
// The number of bins in the histogram.
|
||||
static constexpr int kNumBins = 2048;
|
||||
|
||||
@@ -574,8 +575,16 @@ static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowDecode(
|
||||
|
||||
// The range of logits within the row.
|
||||
int rowStart = 0;
|
||||
int seq_len = seqLens[rowIdx / next_n];
|
||||
int rowEnd = max(0, seq_len - next_n + (rowIdx % next_n) + 1);
|
||||
int batch_idx = rowIdx / next_n;
|
||||
int next_n_idx = rowIdx % next_n;
|
||||
// seqLensIs2D=0: 1D seqLens — all rows in a batch share the same seq_len;
|
||||
// kernel computes per-row effective length via offset.
|
||||
// seqLensIs2D=1: 2D seqLens — each logit row has its own pre-computed
|
||||
// effective length (flat index rowIdx = b*next_n + j maps
|
||||
// directly to seqLens[b, j] in C-contiguous layout).
|
||||
int seq_len = seqLensIs2D ? seqLens[rowIdx] : seqLens[batch_idx];
|
||||
int rowEnd =
|
||||
seqLensIs2D ? max(0, seq_len) : max(0, seq_len - next_n + next_n_idx + 1);
|
||||
|
||||
// Local pointers to this block
|
||||
if constexpr (!multipleBlocksPerRow && !mergeBlocks) {
|
||||
@@ -653,6 +662,11 @@ void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
const auto numColumns = logits.size(1);
|
||||
|
||||
// True if seqLens is 2D (B, next_n): each logit row has its own pre-computed
|
||||
// effective seq_len. False if seqLens is 1D (B,): all rows in a batch share
|
||||
// the same seq_len and the kernel computes the per-row offset itself.
|
||||
int seqLensIs2D = seqLens.dim() == 2 ? 1 : 0;
|
||||
|
||||
if (numColumns < kSortingAlgorithmThreshold) {
|
||||
// Use insertion sort
|
||||
vllm::topKPerRowDecode<kNumThreadsPerBlock, false>
|
||||
@@ -660,7 +674,7 @@ void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
|
||||
logits.data_ptr<float>(), seqLens.data_ptr<int>(),
|
||||
indices.data_ptr<int>(), static_cast<int>(stride0),
|
||||
static_cast<int>(stride1), static_cast<int>(topK),
|
||||
static_cast<int>(next_n));
|
||||
static_cast<int>(next_n), seqLensIs2D);
|
||||
} else if (numColumns < kSplitWorkThreshold) {
|
||||
// From this threshold, use radix sort instead
|
||||
vllm::topKPerRowDecode<kNumThreadsPerBlock, true>
|
||||
@@ -668,7 +682,7 @@ void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
|
||||
logits.data_ptr<float>(), seqLens.data_ptr<int>(),
|
||||
indices.data_ptr<int>(), static_cast<int>(stride0),
|
||||
static_cast<int>(stride1), static_cast<int>(topK),
|
||||
static_cast<int>(next_n));
|
||||
static_cast<int>(next_n), seqLensIs2D);
|
||||
} else {
|
||||
// Long sequences are run in two steps
|
||||
constexpr auto multipleBlocksPerRowConfig = 10;
|
||||
@@ -686,15 +700,16 @@ void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
|
||||
logits.data_ptr<float>(), seqLens.data_ptr<int>(),
|
||||
outIndicesAux.data_ptr<int>(), static_cast<int>(stride0),
|
||||
static_cast<int>(stride1), static_cast<int>(topK),
|
||||
static_cast<int>(next_n), outLogitsAux.data_ptr<float>());
|
||||
static_cast<int>(next_n), seqLensIs2D,
|
||||
outLogitsAux.data_ptr<float>());
|
||||
|
||||
constexpr int kNumThreadsPerBlockMerge = 1024;
|
||||
vllm::topKPerRowDecode<kNumThreadsPerBlockMerge, true, false, true>
|
||||
<<<numRows, kNumThreadsPerBlockMerge, topK * sizeof(int32_t), stream>>>(
|
||||
outLogitsAux.data_ptr<float>(), seqLens.data_ptr<int>(),
|
||||
indices.data_ptr<int>(), multipleBlocksPerRowConfig * topK, 1,
|
||||
static_cast<int>(topK), static_cast<int>(next_n), nullptr,
|
||||
multipleBlocksPerRowConfig, outIndicesAux.data_ptr<int>());
|
||||
static_cast<int>(topK), static_cast<int>(next_n), seqLensIs2D,
|
||||
nullptr, multipleBlocksPerRowConfig, outIndicesAux.data_ptr<int>());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
487
csrc/topk.cu
487
csrc/topk.cu
@@ -1,373 +1,156 @@
|
||||
// Portions of this file are adapted from SGLang PR:
|
||||
// https://github.com/sgl-project/sglang/pull/11194
|
||||
// and
|
||||
// https://github.com/sgl-project/sglang/pull/17747
|
||||
// Persistent TopK kernel for DeepSeek V3 sparse attention indexer.
|
||||
// See persistent_topk.cuh for kernel implementation.
|
||||
|
||||
#include "cuda_compat.h"
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
#include <torch/cuda.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/all.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <algorithm>
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <cub/cub.cuh>
|
||||
#else
|
||||
#include <hipcub/hipcub.hpp>
|
||||
#include "persistent_topk.cuh"
|
||||
#endif
|
||||
|
||||
namespace vllm {
|
||||
void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths,
|
||||
torch::Tensor& output, torch::Tensor& workspace, int64_t k,
|
||||
int64_t max_seq_len) {
|
||||
#ifndef USE_ROCM
|
||||
TORCH_CHECK(logits.is_cuda(), "logits must be CUDA tensor");
|
||||
TORCH_CHECK(lengths.is_cuda(), "lengths must be CUDA tensor");
|
||||
TORCH_CHECK(output.is_cuda(), "output must be CUDA tensor");
|
||||
TORCH_CHECK(logits.dtype() == torch::kFloat32, "Only float32 supported");
|
||||
TORCH_CHECK(lengths.dtype() == torch::kInt32, "lengths must be int32");
|
||||
TORCH_CHECK(output.dtype() == torch::kInt32, "output must be int32");
|
||||
TORCH_CHECK(logits.dim() == 2, "logits must be 2D");
|
||||
TORCH_CHECK(lengths.dim() == 1 || lengths.dim() == 2,
|
||||
"lengths must be 1D or 2D");
|
||||
TORCH_CHECK(lengths.is_contiguous(), "lengths must be contiguous");
|
||||
TORCH_CHECK(output.dim() == 2, "output must be 2D");
|
||||
|
||||
constexpr int TopK = 2048; // DeepSeek V3 sparse attention top-k
|
||||
constexpr int kThreadsPerBlock = 1024; // Threads per block
|
||||
const int64_t num_rows = logits.size(0);
|
||||
const int64_t stride = logits.size(1);
|
||||
|
||||
// Shared memory budget
|
||||
#if defined(USE_ROCM)
|
||||
constexpr size_t kSmem = 48 * 1024; // ROCm default: 48KB
|
||||
#else
|
||||
// Reduced from 128KB to 32KB to improve occupancy.
|
||||
// Each radix pass needs at most ~TopK candidates in the threshold bin,
|
||||
// so 4K entries per round (2 rounds = 8K entries = 32KB) is sufficient.
|
||||
constexpr size_t kSmem = 8 * 1024 * sizeof(uint32_t); // 32KB (bytes)
|
||||
#endif
|
||||
TORCH_CHECK(lengths.numel() == num_rows, "lengths size mismatch");
|
||||
TORCH_CHECK(output.size(0) == num_rows && output.size(1) == k,
|
||||
"output size mismatch");
|
||||
namespace P = vllm::persistent;
|
||||
|
||||
struct FastTopKParams {
|
||||
const float* __restrict__ input; // [batch, seq_len] Logits
|
||||
const int32_t* __restrict__ row_starts; // [batch] Offset into each row
|
||||
// (optional)
|
||||
int32_t* __restrict__ indices; // [batch, TopK] Output top-k indices
|
||||
int32_t* __restrict__ lengths; // [batch] Sequence lengths per row
|
||||
int64_t input_stride; // Stride between rows
|
||||
};
|
||||
TORCH_CHECK(k == P::TopK, "k must be 2048");
|
||||
TORCH_CHECK(k <= stride, "k out of range");
|
||||
|
||||
__device__ __forceinline__ auto convert_to_uint32_v2(float x) -> uint32_t {
|
||||
uint32_t bits = __float_as_uint(x);
|
||||
return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u);
|
||||
}
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
__device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t {
|
||||
__half h = __float2half_rn(x);
|
||||
uint16_t bits = __half_as_ushort(h);
|
||||
uint16_t key = (bits & 0x8000) ? static_cast<uint16_t>(~bits)
|
||||
: static_cast<uint16_t>(bits | 0x8000);
|
||||
return static_cast<uint8_t>(key >> 8);
|
||||
}
|
||||
|
||||
__device__ void naive_topk_cuda(const float* __restrict__ logits,
|
||||
int32_t* __restrict__ output_indices,
|
||||
int32_t seq_len) {
|
||||
const int thread_id = threadIdx.x;
|
||||
for (int i = thread_id; i < TopK; i += kThreadsPerBlock) {
|
||||
output_indices[i] = (i < seq_len) ? i : -1;
|
||||
}
|
||||
}
|
||||
|
||||
// Adapted from:
|
||||
// https://github.com/sgl-project/sglang/blob/v0.5.8/sgl-kernel/csrc/elementwise/topk.cu#L87
|
||||
// by: DarkSharpness
|
||||
// which at the same time is an optimized topk kernel copied from tilelang
|
||||
// kernel
|
||||
__device__ void fast_topk_cuda_tl(
|
||||
const float* __restrict__ logits, // Input logits [seq_len]
|
||||
int* __restrict__ output_indices, // Output top-k indices [TopK]
|
||||
int logits_offset, // Starting offset in logits array
|
||||
int seq_len) // Number of valid logits to process
|
||||
{
|
||||
constexpr int RADIX = 256;
|
||||
constexpr int MAX_BUFFERED_ITEMS = kSmem / (2 * sizeof(int));
|
||||
|
||||
alignas(128) __shared__ int shared_histogram[2][RADIX + 128];
|
||||
alignas(128) __shared__ int shared_output_count;
|
||||
alignas(128) __shared__ int shared_threshold_bin;
|
||||
alignas(128) __shared__ int shared_buffered_count[2];
|
||||
|
||||
extern __shared__ int buffered_indices[][MAX_BUFFERED_ITEMS];
|
||||
|
||||
const int thread_id = threadIdx.x;
|
||||
int remaining_k = TopK;
|
||||
|
||||
// Pass 0: Build coarse 8-bit histogram using FP16 high bits
|
||||
if (thread_id < RADIX + 1) {
|
||||
shared_histogram[0][thread_id] = 0;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int idx = thread_id; idx < seq_len; idx += kThreadsPerBlock) {
|
||||
const auto bin = convert_to_uint8(logits[idx + logits_offset]);
|
||||
::atomicAdd(&shared_histogram[0][bin], 1);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Helper: Compute cumulative sum (suffix sum) over histogram using ping-pong
|
||||
// buffers
|
||||
auto compute_cumulative_sum = [&]() {
|
||||
static_assert(1 << 8 == RADIX,
|
||||
"Radix must be 256 for 8 unrolled iterations");
|
||||
#pragma unroll 8
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
if (C10_LIKELY(thread_id < RADIX)) {
|
||||
const int stride = 1 << i;
|
||||
const int src_buffer = i & 1;
|
||||
const int dst_buffer = src_buffer ^ 1;
|
||||
|
||||
int value = shared_histogram[src_buffer][thread_id];
|
||||
if (thread_id < RADIX - stride) {
|
||||
value += shared_histogram[src_buffer][thread_id + stride];
|
||||
}
|
||||
shared_histogram[dst_buffer][thread_id] = value;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
};
|
||||
|
||||
compute_cumulative_sum();
|
||||
|
||||
// Find threshold bin where cumsum crosses remaining_k
|
||||
if (thread_id < RADIX && shared_histogram[0][thread_id] > remaining_k &&
|
||||
shared_histogram[0][thread_id + 1] <= remaining_k) {
|
||||
shared_threshold_bin = thread_id;
|
||||
shared_buffered_count[0] = 0;
|
||||
shared_output_count = 0;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
const int threshold_bin = shared_threshold_bin;
|
||||
remaining_k -= shared_histogram[0][threshold_bin + 1];
|
||||
|
||||
// Early exit if threshold bin perfectly matches remaining_k
|
||||
if (remaining_k == 0) {
|
||||
for (int idx = thread_id; idx < seq_len; idx += kThreadsPerBlock) {
|
||||
const int bin = convert_to_uint8(logits[idx + logits_offset]);
|
||||
if (bin > threshold_bin) {
|
||||
const int output_pos = ::atomicAdd(&shared_output_count, 1);
|
||||
output_indices[output_pos] = idx;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
return;
|
||||
static int num_sms = 0;
|
||||
static int max_smem_per_block = 0;
|
||||
if (num_sms == 0) {
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, device);
|
||||
cudaDeviceGetAttribute(&max_smem_per_block,
|
||||
cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
|
||||
}
|
||||
|
||||
// Prepare for refinement passes: Process threshold bin
|
||||
__syncthreads();
|
||||
if (thread_id < RADIX + 1) {
|
||||
shared_histogram[0][thread_id] = 0;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Scan all elements and:
|
||||
// 1. Write indices > threshold_bin to output
|
||||
// 2. Buffer indices == threshold_bin for refinement
|
||||
// 3. Build histogram for next refinement pass (fused optimization)
|
||||
for (int idx = thread_id; idx < seq_len; idx += kThreadsPerBlock) {
|
||||
const float logit_value = logits[idx + logits_offset];
|
||||
const int bin = convert_to_uint8(logit_value);
|
||||
|
||||
if (bin > threshold_bin) {
|
||||
// in top-k, write to output
|
||||
const int output_pos = ::atomicAdd(&shared_output_count, 1);
|
||||
output_indices[output_pos] = idx;
|
||||
} else if (bin == threshold_bin) {
|
||||
// Candidate for top-k, needs refinement
|
||||
const int buffer_pos = ::atomicAdd(&shared_buffered_count[0], 1);
|
||||
if (C10_LIKELY(buffer_pos < MAX_BUFFERED_ITEMS)) {
|
||||
buffered_indices[0][buffer_pos] = idx;
|
||||
// Fused: Build histogram for next pass
|
||||
const uint32_t fp32_bits = convert_to_uint32_v2(logit_value);
|
||||
const int next_bin = (fp32_bits >> 24) & 0xFF;
|
||||
::atomicAdd(&shared_histogram[0][next_bin], 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// ============================================================================
|
||||
// Passes 1-4: Refine using 8-bit passes over FP32 bits
|
||||
// ============================================================================
|
||||
// FP32 bits [31:0] split into 4 bytes processed MSB-first:
|
||||
// Pass 1: bits [31:24], Pass 2: bits [23:16], Pass 3: bits [15:8], Pass 4:
|
||||
// bits [7:0]
|
||||
#pragma unroll 4
|
||||
for (int pass = 0; pass < 4; ++pass) {
|
||||
__shared__ int shared_final_k; // For final pass: remaining slots to fill
|
||||
const int src_buffer = pass % 2;
|
||||
const int dst_buffer = src_buffer ^ 1;
|
||||
|
||||
// Clamp buffered count to prevent overflow
|
||||
const int raw_buffered = shared_buffered_count[src_buffer];
|
||||
const int num_buffered =
|
||||
(raw_buffered < MAX_BUFFERED_ITEMS) ? raw_buffered : MAX_BUFFERED_ITEMS;
|
||||
|
||||
compute_cumulative_sum();
|
||||
|
||||
// Find threshold bin for this pass
|
||||
if (thread_id < RADIX && shared_histogram[0][thread_id] > remaining_k &&
|
||||
shared_histogram[0][thread_id + 1] <= remaining_k) {
|
||||
shared_threshold_bin = thread_id;
|
||||
shared_buffered_count[dst_buffer] = 0;
|
||||
shared_final_k = remaining_k - shared_histogram[0][thread_id + 1];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
const int threshold_bin = shared_threshold_bin;
|
||||
remaining_k -= shared_histogram[0][threshold_bin + 1];
|
||||
|
||||
// Bit offset for this pass: 24, 16, 8, 0
|
||||
const int bit_offset = 24 - pass * 8;
|
||||
|
||||
// Early exit if threshold bin perfectly matches
|
||||
if (remaining_k == 0) {
|
||||
for (int i = thread_id; i < num_buffered; i += kThreadsPerBlock) {
|
||||
const int idx = buffered_indices[src_buffer][i];
|
||||
const uint32_t fp32_bits =
|
||||
convert_to_uint32_v2(logits[idx + logits_offset]);
|
||||
const int bin = (fp32_bits >> bit_offset) & 0xFF;
|
||||
if (bin > threshold_bin) {
|
||||
const int output_pos = ::atomicAdd(&shared_output_count, 1);
|
||||
output_indices[output_pos] = idx;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
break;
|
||||
}
|
||||
|
||||
// Continue refinement
|
||||
__syncthreads();
|
||||
if (thread_id < RADIX + 1) {
|
||||
shared_histogram[0][thread_id] = 0;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int i = thread_id; i < num_buffered; i += kThreadsPerBlock) {
|
||||
const int idx = buffered_indices[src_buffer][i];
|
||||
const float logit_value = logits[idx + logits_offset];
|
||||
const uint32_t fp32_bits = convert_to_uint32_v2(logit_value);
|
||||
const int bin = (fp32_bits >> bit_offset) & 0xFF;
|
||||
|
||||
if (bin > threshold_bin) {
|
||||
// Definitely in top-k
|
||||
const int output_pos = ::atomicAdd(&shared_output_count, 1);
|
||||
output_indices[output_pos] = idx;
|
||||
} else if (bin == threshold_bin) {
|
||||
if (pass == 3) {
|
||||
// Final pass (bits [7:0]): No more refinement possible
|
||||
// Fill remaining slots in reverse order to maintain descending order
|
||||
const int slot = ::atomicAdd(&shared_final_k, -1);
|
||||
if (slot > 0) {
|
||||
output_indices[TopK - slot] = idx;
|
||||
}
|
||||
} else {
|
||||
// Buffer for next pass and build next histogram
|
||||
const int buffer_pos =
|
||||
::atomicAdd(&shared_buffered_count[dst_buffer], 1);
|
||||
if (C10_LIKELY(buffer_pos < MAX_BUFFERED_ITEMS)) {
|
||||
buffered_indices[dst_buffer][buffer_pos] = idx;
|
||||
// Fused: Build histogram for next pass
|
||||
const int next_bit_offset = bit_offset - 8;
|
||||
const int next_bin = (fp32_bits >> next_bit_offset) & 0xFF;
|
||||
::atomicAdd(&shared_histogram[0][next_bin], 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
__global__ __launch_bounds__(kThreadsPerBlock) void topk_kernel(
|
||||
const FastTopKParams params) {
|
||||
const auto& [input, row_starts, indices, lengths, input_stride] = params;
|
||||
const uint64_t batch_idx = blockIdx.x;
|
||||
const int logits_offset = row_starts == nullptr ? 0 : row_starts[batch_idx];
|
||||
const int seq_len = lengths[batch_idx];
|
||||
int* output_indices = indices + batch_idx * TopK;
|
||||
const float* logits = input + batch_idx * input_stride;
|
||||
|
||||
if (seq_len <= TopK) {
|
||||
// Shortcut: All elements are in top-k
|
||||
return naive_topk_cuda(logits, output_indices, seq_len);
|
||||
if (num_rows > 32 && max_smem_per_block >= 128 * 1024) {
|
||||
cudaError_t status = vllm::FilteredTopKRaggedTransform<float, int32_t>(
|
||||
logits.data_ptr<float>(), output.data_ptr<int32_t>(),
|
||||
lengths.data_ptr<int32_t>(), static_cast<uint32_t>(num_rows),
|
||||
static_cast<uint32_t>(k), static_cast<uint32_t>(stride), stream);
|
||||
TORCH_CHECK(status == cudaSuccess,
|
||||
"FilteredTopK failed: ", cudaGetErrorString(status));
|
||||
} else {
|
||||
return fast_topk_cuda_tl(logits, output_indices, logits_offset, seq_len);
|
||||
}
|
||||
}
|
||||
TORCH_CHECK(workspace.is_cuda(), "workspace must be CUDA tensor");
|
||||
TORCH_CHECK(workspace.dtype() == torch::kUInt8, "workspace must be uint8");
|
||||
|
||||
FastTopKParams get_params(
|
||||
const at::Tensor& score, const at::Tensor& lengths,
|
||||
std::optional<at::Tensor> row_starts_opt = std::nullopt,
|
||||
std::optional<at::Tensor> indices_opt = std::nullopt) {
|
||||
const int64_t batch_size = score.size(0);
|
||||
// Smem cap: smaller smem → more CTAs/group → more per-row parallelism for
|
||||
// large path. Empirically tuned.
|
||||
int effective_max_smem;
|
||||
if (num_rows <= 4) {
|
||||
effective_max_smem =
|
||||
std::min(max_smem_per_block, static_cast<int>(P::kSmemMedium));
|
||||
} else if (num_rows <= 8) {
|
||||
constexpr int kSmemCapMedium = 48 * 1024;
|
||||
effective_max_smem = std::min(max_smem_per_block, kSmemCapMedium);
|
||||
} else {
|
||||
effective_max_smem = max_smem_per_block;
|
||||
}
|
||||
|
||||
TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1,
|
||||
"score must be 2D with contiguous rows");
|
||||
TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous() &&
|
||||
lengths.size(0) == batch_size,
|
||||
"lengths must be 1D contiguous with size matching batch");
|
||||
size_t available_for_ordered =
|
||||
static_cast<size_t>(effective_max_smem) - P::kFixedSmemLarge;
|
||||
uint32_t max_chunk_elements =
|
||||
static_cast<uint32_t>(available_for_ordered / sizeof(uint32_t));
|
||||
|
||||
const int32_t* row_starts_ptr = nullptr;
|
||||
if (row_starts_opt.has_value()) {
|
||||
const auto& row_starts = *row_starts_opt;
|
||||
TORCH_CHECK(row_starts.dim() == 1 && row_starts.size(0) == batch_size,
|
||||
"row_starts must be 1D with size matching batch");
|
||||
row_starts_ptr = row_starts.data_ptr<int32_t>();
|
||||
uint32_t vec_size = 1;
|
||||
if (stride % 4 == 0)
|
||||
vec_size = 4;
|
||||
else if (stride % 2 == 0)
|
||||
vec_size = 2;
|
||||
|
||||
max_chunk_elements = (max_chunk_elements / vec_size) * vec_size;
|
||||
uint32_t min_chunk = vec_size * P::kThreadsPerBlock;
|
||||
if (max_chunk_elements < min_chunk) max_chunk_elements = min_chunk;
|
||||
|
||||
uint32_t ctas_per_group =
|
||||
(static_cast<uint32_t>(stride) + max_chunk_elements - 1) /
|
||||
max_chunk_elements;
|
||||
uint32_t chunk_size =
|
||||
(static_cast<uint32_t>(stride) + ctas_per_group - 1) / ctas_per_group;
|
||||
chunk_size = ((chunk_size + vec_size - 1) / vec_size) * vec_size;
|
||||
if (chunk_size > max_chunk_elements) chunk_size = max_chunk_elements;
|
||||
|
||||
size_t smem_size = P::kFixedSmemLarge + chunk_size * sizeof(uint32_t);
|
||||
if (smem_size < P::kSmemMedium) smem_size = P::kSmemMedium;
|
||||
|
||||
int occupancy = 1;
|
||||
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&occupancy, P::persistent_topk_kernel<4>, P::kThreadsPerBlock,
|
||||
smem_size);
|
||||
if (occupancy < 1) occupancy = 1;
|
||||
|
||||
uint32_t max_resident_ctas = static_cast<uint32_t>(num_sms) * occupancy;
|
||||
uint32_t num_groups = std::min(max_resident_ctas / ctas_per_group,
|
||||
static_cast<uint32_t>(num_rows));
|
||||
if (num_groups == 0) num_groups = 1;
|
||||
uint32_t total_ctas = num_groups * ctas_per_group;
|
||||
|
||||
size_t state_bytes = num_groups * sizeof(P::RadixRowState);
|
||||
TORCH_CHECK(workspace.size(0) >= static_cast<int64_t>(state_bytes),
|
||||
"workspace too small, need ", state_bytes, " bytes");
|
||||
|
||||
P::PersistentTopKParams params;
|
||||
params.input = logits.data_ptr<float>();
|
||||
params.output = output.data_ptr<int32_t>();
|
||||
params.lengths = lengths.data_ptr<int32_t>();
|
||||
params.num_rows = static_cast<uint32_t>(num_rows);
|
||||
params.stride = static_cast<uint32_t>(stride);
|
||||
params.chunk_size = chunk_size;
|
||||
params.row_states =
|
||||
reinterpret_cast<P::RadixRowState*>(workspace.data_ptr<uint8_t>());
|
||||
params.ctas_per_group = ctas_per_group;
|
||||
params.max_seq_len = static_cast<uint32_t>(max_seq_len);
|
||||
|
||||
#define LAUNCH_PERSISTENT(VS) \
|
||||
do { \
|
||||
auto kernel = &P::persistent_topk_kernel<VS>; \
|
||||
cudaError_t err = cudaFuncSetAttribute( \
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); \
|
||||
TORCH_CHECK(err == cudaSuccess, \
|
||||
"Failed to set smem: ", cudaGetErrorString(err)); \
|
||||
kernel<<<total_ctas, P::kThreadsPerBlock, smem_size, stream>>>(params); \
|
||||
} while (0)
|
||||
|
||||
if (vec_size == 4) {
|
||||
LAUNCH_PERSISTENT(4);
|
||||
} else if (vec_size == 2) {
|
||||
LAUNCH_PERSISTENT(2);
|
||||
} else {
|
||||
LAUNCH_PERSISTENT(1);
|
||||
}
|
||||
#undef LAUNCH_PERSISTENT
|
||||
}
|
||||
|
||||
int32_t* indices_ptr = nullptr;
|
||||
if (indices_opt.has_value()) {
|
||||
const auto& indices = *indices_opt;
|
||||
TORCH_CHECK(indices.dim() == 2 && indices.is_contiguous() &&
|
||||
indices.size(0) == batch_size && indices.size(1) == TopK,
|
||||
"indices must be 2D contiguous [batch, TopK]");
|
||||
indices_ptr = indices.data_ptr<int32_t>();
|
||||
}
|
||||
|
||||
return FastTopKParams{
|
||||
.input = score.data_ptr<float>(),
|
||||
.row_starts = row_starts_ptr,
|
||||
.indices = indices_ptr,
|
||||
.lengths = lengths.data_ptr<int32_t>(),
|
||||
.input_stride = score.stride(0),
|
||||
};
|
||||
}
|
||||
|
||||
template <auto* kernel_func, size_t smem_bytes>
|
||||
void setup_kernel_smem_once() {
|
||||
static const cudaError_t result = []() -> cudaError_t {
|
||||
#ifdef USE_ROCM
|
||||
auto func_ptr = reinterpret_cast<const void*>(kernel_func);
|
||||
cudaError_t err = cudaGetLastError();
|
||||
TORCH_CHECK(err == cudaSuccess,
|
||||
"persistent_topk failed: ", cudaGetErrorString(err));
|
||||
#else
|
||||
auto func_ptr = kernel_func;
|
||||
TORCH_CHECK(false, "persistent_topk is not supported on ROCm");
|
||||
#endif
|
||||
return cudaFuncSetAttribute(
|
||||
func_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
|
||||
}();
|
||||
|
||||
TORCH_CHECK(
|
||||
result == cudaSuccess,
|
||||
"Failed to set kernel shared memory limit: ", cudaGetErrorString(result));
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
void large_context_topk(
|
||||
const torch::Tensor& logits, torch::Tensor& indices,
|
||||
const torch::Tensor& seq_lens,
|
||||
std::optional<torch::Tensor> row_starts = std::nullopt) {
|
||||
TORCH_CHECK(logits.is_cuda(), "logits must be a CUDA tensor");
|
||||
TORCH_CHECK(indices.is_cuda(), "indices must be a CUDA tensor");
|
||||
TORCH_CHECK(seq_lens.is_cuda(), "seq_lens must be a CUDA tensor");
|
||||
if (row_starts.has_value()) {
|
||||
TORCH_CHECK(row_starts->is_cuda(), "row_starts must be a CUDA tensor");
|
||||
}
|
||||
|
||||
const auto params = vllm::get_params(logits, seq_lens, row_starts, indices);
|
||||
const int64_t batch_size = logits.size(0);
|
||||
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
const dim3 grid(static_cast<uint32_t>(batch_size));
|
||||
const dim3 block(vllm::kThreadsPerBlock);
|
||||
|
||||
vllm::setup_kernel_smem_once<vllm::topk_kernel, vllm::kSmem>();
|
||||
vllm::topk_kernel<<<grid, block, vllm::kSmem, stream>>>(params);
|
||||
|
||||
const cudaError_t result = cudaGetLastError();
|
||||
TORCH_CHECK(result == cudaSuccess,
|
||||
"large_context_topk kernel failed: ", cudaGetErrorString(result));
|
||||
}
|
||||
@@ -2,7 +2,6 @@
|
||||
#include "cuda_utils.h"
|
||||
#include "ops.h"
|
||||
#include "core/registration.h"
|
||||
|
||||
#include <torch/library.h>
|
||||
#include <torch/version.h>
|
||||
|
||||
@@ -73,7 +72,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
" Tensor prefix_output,"
|
||||
" Tensor prefix_lse,"
|
||||
" Tensor suffix_output,"
|
||||
" Tensor suffix_lse) -> ()");
|
||||
" Tensor suffix_lse,"
|
||||
" int!? prefill_tokens_with_context,"
|
||||
" Tensor? output_scale=None) -> ()");
|
||||
ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states);
|
||||
#ifndef USE_ROCM
|
||||
ops.def(
|
||||
@@ -109,12 +110,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
|
||||
ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
// Fused SiLU+Mul + per-block quantization
|
||||
ops.def(
|
||||
"silu_and_mul_nvfp4_quant(Tensor! result, Tensor! result_block_scale, "
|
||||
"Tensor input, Tensor input_global_scale) -> ()");
|
||||
ops.impl("silu_and_mul_nvfp4_quant", torch::kCUDA, &silu_and_mul_nvfp4_quant);
|
||||
#endif
|
||||
"silu_and_mul_per_block_quant("
|
||||
"Tensor! out, "
|
||||
"Tensor input, "
|
||||
"Tensor! scales, "
|
||||
"int group_size, "
|
||||
"Tensor? scale_ub=None, "
|
||||
"bool is_scale_transposed=False) -> ()");
|
||||
ops.impl("silu_and_mul_per_block_quant", torch::kCUDA,
|
||||
&silu_and_mul_per_block_quant);
|
||||
|
||||
ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()");
|
||||
ops.impl("mul_and_silu", torch::kCUDA, &mul_and_silu);
|
||||
@@ -191,10 +197,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.impl("top_k_per_row_decode", torch::kCUDA, &top_k_per_row_decode);
|
||||
|
||||
ops.def(
|
||||
"large_context_topk(Tensor score, Tensor indices, Tensor lengths, "
|
||||
"Tensor? "
|
||||
"row_starts_opt) -> ()");
|
||||
ops.impl("large_context_topk", torch::kCUDA, &large_context_topk);
|
||||
"persistent_topk(Tensor logits, Tensor lengths, Tensor! output, "
|
||||
"Tensor workspace, int k, int max_seq_len) -> ()");
|
||||
ops.impl("persistent_topk", torch::kCUDA, &persistent_topk);
|
||||
|
||||
// Layernorm-quant
|
||||
// Apply Root Mean Square (RMS) Normalization to the input tensor.
|
||||
@@ -332,47 +337,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"Tensor? qzeros_or_none, bool inplace) -> Tensor");
|
||||
// conditionally compiled so impl registrations are in source file
|
||||
|
||||
// CUTLASS w4a8 GEMM
|
||||
ops.def(
|
||||
"cutlass_w4a8_mm("
|
||||
" Tensor A,"
|
||||
" Tensor B,"
|
||||
" Tensor group_scales,"
|
||||
" int group_size,"
|
||||
" Tensor channel_scales,"
|
||||
" Tensor token_scales,"
|
||||
" ScalarType? out_type,"
|
||||
" str? maybe_schedule"
|
||||
") -> Tensor");
|
||||
// pack scales
|
||||
ops.def("cutlass_pack_scale_fp8(Tensor scales) -> Tensor");
|
||||
// encode and reorder weight matrix
|
||||
ops.def("cutlass_encode_and_reorder_int4b(Tensor B) -> Tensor");
|
||||
// conditionally compiled so impl registration is in source file
|
||||
|
||||
// CUTLASS w4a8 grouped GEMM
|
||||
ops.def(
|
||||
"cutlass_w4a8_moe_mm("
|
||||
" Tensor! out_tensors,"
|
||||
" Tensor a_tensors,"
|
||||
" Tensor b_tensors,"
|
||||
" Tensor a_scales,"
|
||||
" Tensor b_scales,"
|
||||
" Tensor b_group_scales,"
|
||||
" int b_group_size,"
|
||||
" Tensor expert_offsets,"
|
||||
" Tensor problem_sizes,"
|
||||
" Tensor a_strides,"
|
||||
" Tensor b_strides,"
|
||||
" Tensor c_strides,"
|
||||
" Tensor group_scale_strides,"
|
||||
" str? maybe_schedule"
|
||||
") -> ()");
|
||||
ops.def(
|
||||
"cutlass_encode_and_reorder_int4b_grouped(Tensor b_tensors) -> (Tensor, "
|
||||
"Tensor)");
|
||||
// conditionally compiled so impl registration is in source file
|
||||
|
||||
#endif
|
||||
|
||||
// Dequantization for GGML.
|
||||
@@ -409,20 +373,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.def("ggml_moe_get_block_size", &ggml_moe_get_block_size);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
// CUTLASS nvfp4 block scaled GEMM
|
||||
ops.def(
|
||||
"cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
|
||||
" Tensor block_scale_a, Tensor block_scale_b,"
|
||||
" Tensor alpha) -> ()");
|
||||
ops.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm);
|
||||
|
||||
// cutlass nvfp4 block scaled group GEMM
|
||||
ops.def(
|
||||
"cutlass_fp4_group_mm(Tensor! out, Tensor a, Tensor b,"
|
||||
" Tensor a_blockscale, Tensor b_blockscales, Tensor alphas,"
|
||||
" Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()");
|
||||
// conditionally compiled so impl registration is in source file
|
||||
|
||||
// Expert-specialization mxfp8 blockscaled grouped quantization (SM100+).
|
||||
ops.def(
|
||||
"mxfp8_experts_quant("
|
||||
@@ -455,44 +405,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"-> int");
|
||||
// conditionally compiled so impl in source file
|
||||
|
||||
// Compute NVFP4 block quantized tensor.
|
||||
ops.def(
|
||||
"scaled_fp4_quant(Tensor input,"
|
||||
" Tensor input_scale, bool "
|
||||
"is_sf_swizzled_layout) -> (Tensor, Tensor)");
|
||||
ops.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant_func);
|
||||
|
||||
// Out variant
|
||||
// TODO: Add {at::Tag::out_variant} tag and update all call sites
|
||||
// to use the functional variant once vLLM upgrades PyTorch.
|
||||
// See pytorch/pytorch#176117.
|
||||
ops.def(
|
||||
"scaled_fp4_quant.out(Tensor input,"
|
||||
" Tensor input_scale, bool "
|
||||
"is_sf_swizzled_layout, *, Tensor(a!) output, Tensor(b!) output_scale) "
|
||||
"-> ()");
|
||||
ops.impl("scaled_fp4_quant.out", torch::kCUDA, &scaled_fp4_quant_out);
|
||||
|
||||
// Compute NVFP4 experts quantization.
|
||||
ops.def(
|
||||
"scaled_fp4_experts_quant(Tensor! output, Tensor! output_scale,"
|
||||
"Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
|
||||
"Tensor output_scale_offset_by_experts) -> ()");
|
||||
ops.impl("scaled_fp4_experts_quant", torch::kCUDA, &scaled_fp4_experts_quant);
|
||||
|
||||
// Fused SiLU+Mul+NVFP4 experts quantization.
|
||||
ops.def(
|
||||
"silu_and_mul_scaled_fp4_experts_quant(Tensor! output, Tensor! "
|
||||
"output_scale,"
|
||||
"Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
|
||||
"Tensor output_scale_offset_by_experts) -> ()");
|
||||
ops.impl("silu_and_mul_scaled_fp4_experts_quant", torch::kCUDA,
|
||||
&silu_and_mul_scaled_fp4_experts_quant);
|
||||
|
||||
// Check if cutlass_scaled_mm_fp4 is supported for CUDA devices
|
||||
// of the given capability
|
||||
ops.def("cutlass_scaled_mm_supports_fp4(int cuda_device_capability) -> bool");
|
||||
ops.impl("cutlass_scaled_mm_supports_fp4", &cutlass_scaled_mm_supports_fp4);
|
||||
#endif
|
||||
|
||||
// Quantized GEMM for GPTQ.
|
||||
@@ -596,6 +508,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||
" int block_size_in_bytes, Tensor block_mapping) -> ()");
|
||||
cache_ops.impl("swap_blocks", torch::kCUDA, &swap_blocks);
|
||||
|
||||
// Batch swap: submit all block copies in a single driver call.
|
||||
cache_ops.def(
|
||||
"swap_blocks_batch(Tensor src_ptrs, Tensor dst_ptrs,"
|
||||
" Tensor sizes) -> ()");
|
||||
cache_ops.impl("swap_blocks_batch", torch::kCPU, &swap_blocks_batch);
|
||||
|
||||
// Reshape the key and value tensors and cache them.
|
||||
cache_ops.def(
|
||||
"reshape_and_cache(Tensor key, Tensor value,"
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user