Compare commits

..

10 Commits

Author SHA1 Message Date
Stefano Castagnetta
c284a6671c [Bugfix] Restrict TRTLLM attention to SM100, fixing GB300 (SM103) hang (#38730)
Signed-off-by: Stefano Castagnetta <scastagnetta@nvidia.com>
(cherry picked from commit 6183cae1bd)
2026-04-01 12:11:03 -07:00
Chauncey
3a30a1a6a8 [Misc] Rename think_start_str/think_end_str to reasoning_start_str/reasoning_end_str (#38242)
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
(cherry picked from commit cbe7d18096)
2026-04-01 12:10:53 -07:00
Juan Pérez de Algaba
29982d48b3 (security) Enforce frame limit in VideoMediaIO (#38636)
Signed-off-by: jperezde <jperezde@redhat.com>
(cherry picked from commit 58ee614221)
2026-04-01 12:10:40 -07:00
Yifan Qiao
1dbbafd3f3 [Feat][v1] Simple yet General CPU KV Cache Offloading (#37160)
Signed-off-by: Yifan Qiao <yifanqiao@berkeley.edu>
Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
(cherry picked from commit 91e4521f9f)
2026-04-01 01:03:14 -07:00
Lucas Wilkinson
0ee3b7fc3d [Bugfix][MLA] Add logits size budget to sparse indexer prefill chunking (#36178)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
(cherry picked from commit eb47454987)
2026-04-01 01:02:58 -07:00
Matthew Bonanni
268bed9cf3 [Bugfix][Async] Fix async spec decoding with hybrid models (#38556)
Signed-off-by: SandishKumarHN <sandishkumarhn@gmail.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: SandishKumarHN <sandishkumarhn@gmail.com>
(cherry picked from commit 757068dc65)
2026-04-01 01:02:35 -07:00
Jiangyun Zhu
bcc0fdd0f3 [CI] fix LM Eval Qwen3.5 Models (B200) (#38632)
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
(cherry picked from commit ea7bfde6e4)
2026-04-01 01:02:20 -07:00
wang.yuqi
69b8bd4b33 [CI Failure] pin colmodernvbert revision (#38612)
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Signed-off-by: wang.yuqi <noooop@126.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
(cherry picked from commit 719735d6c5)
2026-04-01 01:02:04 -07:00
Li, Jiang
12449f9492 [Bugfix][CPU] Skip set_num_threads after thread binding (#38535)
Signed-off-by: jiang1.li <jiang1.li@intel.com>
(cherry picked from commit 6557f4937f)
2026-03-30 23:01:42 -07:00
haosdent
b92312dfd7 [CI] Fix SPLADE pooler test broken by #38139 (#38495)
Signed-off-by: haosdent <haosdent@gmail.com>
(cherry picked from commit a08b7733fd)
2026-03-30 21:52:13 -07:00
621 changed files with 9524 additions and 34067 deletions

View File

@@ -5,7 +5,6 @@ steps:
depends_on: []
device: amd_cpu
no_plugin: true
soft_fail: true
commands:
- >
docker build
@@ -21,3 +20,11 @@ steps:
- docker push "rocm/vllm-ci:${BUILDKITE_COMMIT}"
env:
DOCKER_BUILDKIT: "1"
retry:
automatic:
- exit_status: -1 # Agent was lost
limit: 1
- exit_status: -10 # Agent was lost
limit: 1
- exit_status: 1 # Machine occasionally fail
limit: 1

View File

@@ -13,14 +13,12 @@ steps:
- tests/kernels/attention/test_cpu_attn.py
- tests/kernels/moe/test_cpu_fused_moe.py
- tests/kernels/test_onednn.py
- tests/kernels/test_awq_int4_to_int8.py
commands:
- |
bash .buildkite/scripts/hardware_ci/run-cpu-test.sh 20m "
pytest -x -v -s tests/kernels/attention/test_cpu_attn.py
pytest -x -v -s tests/kernels/moe/test_cpu_fused_moe.py
pytest -x -v -s tests/kernels/test_onednn.py
pytest -x -v -s tests/kernels/test_awq_int4_to_int8.py"
pytest -x -v -s tests/kernels/test_onednn.py"
- label: CPU-Compatibility Tests
depends_on: []

View File

@@ -36,7 +36,6 @@
"model": "meta-llama/Llama-3.1-8B-Instruct",
"backend": "vllm",
"ignore-eos": "",
"temperature": 0,
"num_prompts": 200
}
},
@@ -128,4 +127,4 @@
}
}
]
}
}

View File

@@ -22,7 +22,6 @@
"hf_split": "test",
"no_stream": "",
"no_oversample": "",
"temperature": 0,
"num_prompts": 200
}
},

View File

@@ -26,7 +26,6 @@
"model": "meta-llama/Llama-3.1-8B-Instruct",
"backend": "vllm",
"ignore-eos": "",
"temperature": 0,
"num_prompts": 200
}
},

View File

@@ -26,7 +26,6 @@
"model": "meta-llama/Llama-3.1-8B-Instruct",
"backend": "vllm",
"ignore-eos": "",
"temperature": 0,
"num_prompts": 200
}
},

View File

@@ -21,7 +21,6 @@
"backend": "vllm",
"dataset_name": "sharegpt",
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
"temperature": 0,
"num_prompts": 200
}
},
@@ -48,7 +47,6 @@
"backend": "vllm",
"dataset_name": "sharegpt",
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
"temperature": 0,
"num_prompts": 200
}
},
@@ -75,7 +73,6 @@
"backend": "vllm",
"dataset_name": "sharegpt",
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
"temperature": 0,
"num_prompts": 200
}
},
@@ -103,7 +100,6 @@
"backend": "vllm",
"dataset_name": "sharegpt",
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
"temperature": 0,
"num_prompts": 200
}
},
@@ -131,7 +127,6 @@
"backend": "vllm",
"dataset_name": "sharegpt",
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
"temperature": 0,
"num_prompts": 200
}
},
@@ -156,7 +151,6 @@
"backend": "vllm",
"dataset_name": "sharegpt",
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
"temperature": 0,
"num_prompts": 200
}
}

View File

@@ -13,7 +13,6 @@
"backend": "vllm",
"dataset_name": "sharegpt",
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
"temperature": 0,
"num_prompts": 200
}
},
@@ -31,7 +30,6 @@
"backend": "vllm",
"dataset_name": "sharegpt",
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
"temperature": 0,
"num_prompts": 200
}
},
@@ -49,7 +47,6 @@
"backend": "vllm",
"dataset_name": "sharegpt",
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
"temperature": 0,
"num_prompts": 200
}
},
@@ -70,7 +67,6 @@
"backend": "vllm",
"dataset_name": "sharegpt",
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
"temperature": 0,
"num_prompts": 200
}
}

View File

@@ -239,29 +239,13 @@ fi
# --- Docker housekeeping ---
cleanup_docker
aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin "$REGISTRY"
aws ecr get-login-password --region us-east-1 | docker login --username AWS --password-stdin 936637512419.dkr.ecr.us-east-1.amazonaws.com
# --- Build or pull test image ---
IMAGE="${IMAGE_TAG_XPU:-${image_name}}"
echo "Using image: ${IMAGE}"
if docker image inspect "${IMAGE}" >/dev/null 2>&1; then
echo "Image already exists locally, skipping pull"
if [[ -n "${IMAGE_TAG_XPU:-}" ]]; then
echo "Using prebuilt XPU image: ${IMAGE_TAG_XPU}"
docker pull "${IMAGE_TAG_XPU}"
else
echo "Image not found locally, waiting for lock..."
flock /tmp/docker-pull.lock bash -c "
if docker image inspect '${IMAGE}' >/dev/null 2>&1; then
echo 'Image already pulled by another runner'
else
echo 'Pulling image...'
timeout 900 docker pull '${IMAGE}'
fi
"
echo "Pull step completed"
echo "Using prebuilt XPU image: ${image_name}"
docker pull "${image_name}"
fi
remove_docker_container() {

View File

@@ -42,7 +42,6 @@ docker run \
python3 examples/basic/offline_inference/generate.py --model superjob/Qwen3-4B-Instruct-2507-GPTQ-Int4 --block-size 64 --enforce-eager --max-model-len 8192
python3 examples/basic/offline_inference/generate.py --model ibm-research/PowerMoE-3b --block-size 64 --enforce-eager -tp 2
python3 examples/basic/offline_inference/generate.py --model ibm-research/PowerMoE-3b --block-size 64 --enforce-eager -tp 2 --enable-expert-parallel
python3 examples/basic/offline_inference/generate.py --model OPEA/Qwen2.5-0.5B-Instruct-int4-sym-inc --block-size 64 --enforce-eager --max-model-len 8192
cd tests
pytest -v -s v1/core --ignore=v1/core/test_reset_prefix_cache_e2e.py --ignore=v1/core/test_scheduler_e2e.py
pytest -v -s v1/engine

View File

@@ -790,7 +790,7 @@ steps:
- tests/kernels/helion/
- vllm/platforms/rocm.py
commands:
- pip install helion==0.3.3
- pip install helion
- pytest -v -s kernels/helion/

View File

@@ -2,6 +2,14 @@ group: Benchmarks
depends_on:
- image-build
steps:
- label: Benchmarks
timeout_in_minutes: 20
working_dir: "/vllm-workspace/.buildkite"
source_file_dependencies:
- benchmarks/
commands:
- bash scripts/run-benchmarks.sh
- label: Benchmarks CLI Test
timeout_in_minutes: 20
source_file_dependencies:

View File

@@ -72,7 +72,6 @@ steps:
- vllm/v1/attention/backends/flashinfer.py
- vllm/compilation/ # TODO(luka) limit to vllm/compilation/passes
- tests/compile/passes/test_fusion_attn.py
- tests/compile/passes/test_mla_attn_quant_fusion.py
- tests/compile/passes/test_silu_mul_quant_fusion.py
- tests/compile/passes/distributed/test_fusion_all_reduce.py
- tests/compile/fullgraph/test_full_graph.py
@@ -80,7 +79,6 @@ steps:
# b200 runners are limited, so we limit the tests to the minimum set only supported on Blackwell
- nvidia-smi
- pytest -v -s tests/compile/passes/test_fusion_attn.py -k FLASHINFER
- pytest -v -s tests/compile/passes/test_mla_attn_quant_fusion.py
- pytest -v -s tests/compile/passes/test_silu_mul_quant_fusion.py
# this runner has 2 GPUs available even though num_devices=2 is not set
- pytest -v -s tests/compile/passes/distributed/test_fusion_all_reduce.py

View File

@@ -224,20 +224,6 @@ steps:
commands:
- ./.buildkite/scripts/run-multi-node-test.sh /vllm-workspace/tests 2 2 $IMAGE_TAG "VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed' && NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed' && python3 ../examples/offline_inference/data_parallel.py -dp=2 -tp=1 --dp-num-nodes=2 --dp-node-rank=0 --dp-master-addr=192.168.10.10 --dp-master-port=12345 --enforce-eager --trust-remote-code && VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py && VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py" "VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed' && NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed' && python3 ../examples/offline_inference/data_parallel.py -dp=2 -tp=1 --dp-num-nodes=2 --dp-node-rank=1 --dp-master-addr=192.168.10.10 --dp-master-port=12345 --enforce-eager --trust-remote-code"
- label: MessageQueue TCP Multi-Node (2 GPUs)
timeout_in_minutes: 10
working_dir: "/vllm-workspace/tests"
num_devices: 1
num_nodes: 2
no_plugin: true
optional: true
source_file_dependencies:
- vllm/distributed/device_communicators/shm_broadcast.py
- vllm/distributed/parallel_state.py
- tests/distributed/test_mq_tcp_multinode.py
commands:
- ./.buildkite/scripts/run-multi-node-test.sh /vllm-workspace/tests 2 1 $IMAGE_TAG "torchrun --nnodes 2 --nproc-per-node=1 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_mq_tcp_multinode.py" "torchrun --nnodes 2 --nproc-per-node=1 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_mq_tcp_multinode.py"
- label: Distributed NixlConnector PD accuracy (4 GPUs)
timeout_in_minutes: 30
working_dir: "/vllm-workspace/tests"
@@ -308,23 +294,3 @@ steps:
commands:
- pytest -v -s distributed/test_pp_cudagraph.py
- pytest -v -s distributed/test_pipeline_parallel.py
- label: RayExecutorV2 (4 GPUs)
timeout_in_minutes: 60
working_dir: "/vllm-workspace/tests"
num_devices: 4
source_file_dependencies:
- vllm/v1/executor/ray_executor_v2.py
- vllm/v1/executor/abstract.py
- vllm/v1/executor/multiproc_executor.py
- tests/distributed/test_ray_v2_executor.py
- tests/distributed/test_ray_v2_executor_e2e.py
- tests/distributed/test_pipeline_parallel.py
- tests/basic_correctness/test_basic_correctness.py
commands:
- export VLLM_USE_RAY_V2_EXECUTOR_BACKEND=1
- export NCCL_CUMEM_HOST_ENABLE=0
- pytest -v -s distributed/test_ray_v2_executor.py
- pytest -v -s distributed/test_ray_v2_executor_e2e.py
- pytest -v -s distributed/test_pipeline_parallel.py -k "ray"
- TARGET_TEST_SUITE=L4 pytest -v -s basic_correctness/test_basic_correctness.py -k "ray"

View File

@@ -13,8 +13,8 @@ steps:
- pytest -v -s distributed/test_eplb_algo.py
- pytest -v -s distributed/test_eplb_utils.py
- label: EPLB Execution # 17min
timeout_in_minutes: 27
- label: EPLB Execution
timeout_in_minutes: 20
working_dir: "/vllm-workspace/tests"
num_devices: 4
source_file_dependencies:

View File

@@ -2,16 +2,6 @@ group: Kernels
depends_on:
- image-build
steps:
- label: vLLM IR Tests
timeout_in_minutes: 10
working_dir: "/vllm-workspace/"
source_file_dependencies:
- vllm/ir
- vllm/kernels
commands:
- pytest -v -s tests/ir
- pytest -v -s tests/kernels/ir
- label: Kernels Core Operation Test
timeout_in_minutes: 75
source_file_dependencies:
@@ -29,7 +19,6 @@ steps:
- vllm/v1/attention
# TODO: remove this dependency (https://github.com/vllm-project/vllm/issues/32267)
- vllm/model_executor/layers/attention
- vllm/utils/flashinfer.py
- tests/kernels/attention
commands:
- pytest -v -s kernels/attention --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
@@ -140,7 +129,7 @@ steps:
- vllm/utils/import_utils.py
- tests/kernels/helion/
commands:
- pip install helion==0.3.3
- pip install helion
- pytest -v -s kernels/helion/

View File

@@ -18,6 +18,5 @@ steps:
# Avoid importing model tests that cause CUDA reinitialization error
- pytest models/test_transformers.py -v -s -m 'distributed(num_gpus=2)'
- pytest models/language -v -s -m 'distributed(num_gpus=2)'
- pytest models/multimodal/generation/test_phi4siglip.py -v -s -m 'distributed(num_gpus=2)'
- pytest models/multimodal -v -s -m 'distributed(num_gpus=2)' --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/generation/test_phi4siglip.py
- pytest models/multimodal -v -s -m 'distributed(num_gpus=2)' --ignore models/multimodal/generation/test_whisper.py
- VLLM_WORKER_MULTIPROC_METHOD=spawn pytest models/multimodal/generation/test_whisper.py -v -s -m 'distributed(num_gpus=2)'

30
.github/CODEOWNERS vendored
View File

@@ -2,20 +2,16 @@
# for more info about CODEOWNERS file
# This lists cover the "core" components of vLLM that require careful review
/vllm/compilation @zou3519 @youkaichao @ProExpertProg @BoyuanFeng @vadiklyutiy
/vllm/compilation @zou3519 @youkaichao @ProExpertProg @BoyuanFeng
/vllm/distributed/kv_transfer @NickLucche @ApostaC @orozery
/vllm/lora @jeejeelee
/vllm/model_executor/layers/attention @LucasWilkinson @MatthewBonanni
/vllm/model_executor/layers/fused_moe @mgoin @pavanimajety
/vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth @yewentao256 @pavanimajety
/vllm/model_executor/layers/mamba @tdoublep @tomeras91
/vllm/model_executor/layers/mamba/gdn_linear_attn.py @tdoublep @ZJY0516 @vadiklyutiy
/vllm/model_executor/layers/rotary_embedding.py @vadiklyutiy
/vllm/model_executor/layers/mamba @tdoublep
/vllm/model_executor/layers/mamba/gdn_linear_attn.py @tdoublep @ZJY0516
/vllm/model_executor/model_loader @22quinn
/vllm/model_executor/layers/batch_invariant.py @yewentao256
/vllm/ir @ProExpertProg
/vllm/kernels/ @ProExpertProg @tjtanaa
/vllm/kernels/helion @ProExpertProg @zou3519
/vllm/multimodal @DarkLight1337 @ywang96 @NickLucche @tjtanaa
/vllm/vllm_flash_attn @LucasWilkinson @MatthewBonanni
CMakeLists.txt @tlrmchlsmth @LucasWilkinson
@@ -51,9 +47,9 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
/vllm/v1/attention @LucasWilkinson @MatthewBonanni
/vllm/v1/attention/backend.py @WoosukKwon @zhuohan123 @youkaichao @alexm-redhat @njhill
/vllm/v1/attention/backends/mla @pavanimajety
/vllm/v1/attention/backends/flashinfer.py @mgoin @pavanimajety @vadiklyutiy
/vllm/v1/attention/backends/flashinfer.py @mgoin @pavanimajety
/vllm/v1/attention/backends/triton_attn.py @tdoublep
/vllm/v1/attention/backends/gdn_attn.py @ZJY0516 @vadiklyutiy
/vllm/v1/attention/backends/gdn_attn.py @ZJY0516
/vllm/v1/core @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @alexm-redhat @heheda12345 @ApostaC @orozery
/vllm/v1/sample @22quinn @houseroad @njhill
/vllm/v1/spec_decode @benchislett @luccafong @MatthewBonanni
@@ -75,9 +71,8 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
/tests/distributed/test_pipeline_parallel.py @youkaichao
/tests/distributed/test_same_node.py @youkaichao
/tests/entrypoints @DarkLight1337 @robertgshaw2-redhat @aarnphm @NickLucche
/tests/evals @mgoin @vadiklyutiy
/tests/evals @mgoin
/tests/kernels @mgoin @tlrmchlsmth @WoosukKwon @yewentao256
/tests/kernels/ir @ProExpertProg @tjtanaa
/tests/models @DarkLight1337 @ywang96
/tests/multimodal @DarkLight1337 @ywang96 @NickLucche
/tests/quantization @mgoin @robertgshaw2-redhat @yewentao256 @pavanimajety
@@ -87,7 +82,7 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
/tests/v1/core @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @alexm-redhat @heheda12345 @ApostaC @orozery
/tests/weight_loading @mgoin @youkaichao @yewentao256
/tests/lora @jeejeelee
/tests/models/language/generation/test_hybrid.py @tdoublep @tomeras91
/tests/models/language/generation/test_hybrid.py @tdoublep
/tests/v1/kv_connector/nixl_integration @NickLucche
/tests/v1/kv_connector @ApostaC @orozery
/tests/v1/kv_offload @ApostaC @orozery
@@ -131,14 +126,9 @@ mkdocs.yaml @hmellor
/vllm/platforms/xpu.py @jikunshang
/docker/Dockerfile.xpu @jikunshang
# Nemotron-specific files
/vllm/model_executor/models/*nemotron* @tomeras91
/vllm/transformers_utils/configs/*nemotron* @tomeras91
/tests/**/*nemotron* @tomeras91
# Qwen-specific files
/vllm/model_executor/models/qwen* @sighingnow @vadiklyutiy
/vllm/transformers_utils/configs/qwen* @sighingnow @vadiklyutiy
/vllm/attention/backends/dual_chunk_flash_attn.py @sighingnow
/vllm/model_executor/models/qwen* @sighingnow
# MTP-specific files
/vllm/model_executor/models/deepseek_mtp.py @luccafong
@@ -154,7 +144,7 @@ mkdocs.yaml @hmellor
# Kernels
/vllm/v1/attention/ops/chunked_prefill_paged_decode.py @tdoublep
/vllm/v1/attention/ops/triton_unified_attention.py @tdoublep
/vllm/model_executor/layers/fla @ZJY0516 @vadiklyutiy
/vllm/model_executor/layers/fla @ZJY0516
# ROCm related: specify owner with write access to notify AMD folks for careful code review
/vllm/**/*rocm* @tjtanaa

View File

@@ -28,7 +28,6 @@ jobs:
});
const hasReadyLabel = pr.labels.some(l => l.name === 'ready');
const hasVerifiedLabel = pr.labels.some(l => l.name === 'verified');
const { data: mergedPRs } = await github.rest.search.issuesAndPullRequests({
q: `repo:${context.repo.owner}/${context.repo.repo} is:pr is:merged author:${pr.user.login}`,
@@ -36,10 +35,10 @@ jobs:
});
const mergedCount = mergedPRs.total_count;
if (hasReadyLabel || hasVerifiedLabel || mergedCount >= 4) {
core.info(`Check passed: verified label=${hasVerifiedLabel}, ready label=${hasReadyLabel}, 4+ merged PRs=${mergedCount >= 4}`);
if (hasReadyLabel || mergedCount >= 4) {
core.info(`Check passed: ready label=${hasReadyLabel}, 4+ merged PRs=${mergedCount >= 4}`);
} else {
core.setFailed(`PR must have the 'verified' or 'ready' (which also triggers tests) label or the author must have at least 4 merged PRs (found ${mergedCount}).`);
core.setFailed(`PR must have the 'ready' label or the author must have at least 4 merged PRs (found ${mergedCount}).`);
}
pre-commit:

View File

@@ -39,7 +39,7 @@ repos:
rev: 0.11.1
hooks:
- id: pip-compile
args: [requirements/test.in, -c, requirements/common.txt, -o, requirements/test.txt, --index-strategy, unsafe-best-match, --torch-backend, cu129, --python-platform, x86_64-manylinux_2_28, --python-version, "3.12"]
args: [requirements/test.in, -o, requirements/test.txt, --index-strategy, unsafe-best-match, --torch-backend, cu129, --python-platform, x86_64-manylinux_2_28, --python-version, "3.12"]
files: ^requirements/test\.(in|txt)$
- id: pip-compile
alias: pip-compile-rocm

View File

@@ -309,7 +309,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
# Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building.
set(CUTLASS_REVISION "v4.4.2")
set(CUTLASS_REVISION "v4.2.1")
# Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR})
@@ -340,8 +340,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_EXT_SRC
"csrc/quantization/awq/gemm_kernels.cu"
"csrc/cutlass_extensions/common.cpp"
"csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu")
"csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu"
"csrc/quantization/fp4/nvfp4_quant_entry.cu"
"csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
"csrc/cutlass_extensions/common.cpp")
set_gencode_flags_for_srcs(
SRCS "${VLLM_EXT_SRC}"
@@ -488,6 +490,185 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
" in CUDA target architectures")
endif()
set(SCALED_MM_3X_ARCHS)
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
# CUDA 12.0 or later
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS)
set(SRCS
"csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm90.cu"
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8.cu"
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_int8.cu"
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_azp_sm90_int8.cu"
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM90=1")
# Let scaled_mm_c2x know it doesn't need to build these arches
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
message(STATUS "Building scaled_mm_c3x_sm90 for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS)
message(STATUS "Not building scaled_mm_c3x_sm90 as CUDA Compiler version is "
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
"later if you intend on running FP8 quantized models on "
"Hopper.")
else()
message(STATUS "Not building scaled_mm_c3x_sm90 as no compatible archs found "
"in CUDA target architectures")
endif()
endif()
# The cutlass_scaled_mm kernels for Blackwell SM12x (c3x, i.e. CUTLASS 3.x) require
# CUDA 12.8 or later
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0a;12.1a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
set(SRCS
"csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm120.cu"
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8.cu"
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8.cu"
)
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM120=1")
# Let scaled_mm_c2x know it doesn't need to build these arches
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
message(STATUS "Building scaled_mm_c3x_sm120 for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
message(STATUS "Not building scaled_mm_c3x_sm120 as CUDA Compiler version is "
"not >= 12.8, we recommend upgrading to CUDA 12.8 or "
"later if you intend on running FP8 quantized models on "
"Blackwell.")
else()
message(STATUS "Not building scaled_mm_c3x_120 as no compatible archs found "
"in CUDA target architectures")
endif()
endif()
# The cutlass_scaled_mm kernels for Blackwell SM100 (c3x, i.e. CUTLASS 3.x)
# require CUDA 12.8 or later
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
set(SRCS
"csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm100.cu"
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu"
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8.cu"
)
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM100=1")
# Let scaled_mm_c2x know it doesn't need to build these arches
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
message(STATUS "Building scaled_mm_c3x_sm100 for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
message(STATUS "Not building scaled_mm_c3x_sm100 as CUDA Compiler version is "
"not >= 12.8, we recommend upgrading to CUDA 12.8 or "
"later if you intend on running FP8 quantized models on "
"Blackwell.")
else()
message(STATUS "Not building scaled_mm_c3x_100 as no compatible archs found "
"in CUDA target architectures")
endif()
endif()
#
# For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x)
# kernels for the remaining archs that are not already built for 3x.
# (Build 8.9 for FP8)
cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS
"7.5;8.0;8.7;8.9+PTX" "${CUDA_ARCHS}")
# subtract out the archs that are already built for 3x
list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS})
if (SCALED_MM_2X_ARCHS)
set(SRCS "csrc/quantization/w8a8/cutlass/scaled_mm_c2x.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_2X_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_C2X=1")
message(STATUS "Building scaled_mm_c2x for archs: ${SCALED_MM_2X_ARCHS}")
else()
if (SCALED_MM_3X_ARCHS)
message(STATUS "Not building scaled_mm_c2x as all archs are already built"
" for and covered by scaled_mm_c3x")
else()
message(STATUS "Not building scaled_mm_c2x as no compatible archs found "
"in CUDA target architectures")
endif()
endif()
# The nvfp4_scaled_mm_sm120 kernels for Blackwell SM12x require
# CUDA 12.8 or later
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(FP4_ARCHS "12.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(FP4_ARCHS "12.0a;12.1a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
set(SRCS
"csrc/quantization/fp4/nvfp4_quant_kernels.cu"
"csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
"csrc/quantization/fp4/nvfp4_experts_quant.cu"
"csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu"
"csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${FP4_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM120=1")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM120=1")
message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
else()
message(STATUS "Not building NVFP4 as no compatible archs were found.")
# clear FP4_ARCHS
set(FP4_ARCHS)
endif()
# FP4 Archs and flags
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(FP4_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(FP4_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
set(SRCS
"csrc/quantization/fp4/nvfp4_quant_kernels.cu"
"csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
"csrc/quantization/fp4/nvfp4_experts_quant.cu"
"csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu"
"csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${FP4_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM100=1")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
else()
message(STATUS "Not building NVFP4 as no compatible archs were found.")
# clear FP4_ARCHS
set(FP4_ARCHS)
endif()
# CUTLASS MLA Archs and flags
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(MLA_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
@@ -512,6 +693,55 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set(MLA_ARCHS)
endif()
# CUTLASS MoE kernels
# The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and ONLY works
# on Hopper). get_cutlass_(batched_)moe_mm_data should only be compiled
# if it's possible to compile MoE kernels that use its output.
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
set(SRCS "csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm90.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM90=1")
message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is "
"not >= 12.3, we recommend upgrading to CUDA 12.3 or later "
"if you intend on running FP8 quantized MoE models on Hopper.")
else()
message(STATUS "Not building grouped_mm_c3x as no compatible archs found "
"in CUDA target architectures.")
endif()
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
set(SRCS "csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm100.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is "
"not >= 12.8, we recommend upgrading to CUDA 12.8 or later "
"if you intend on running FP8 quantized MoE models on Blackwell.")
else()
message(STATUS "Not building grouped_mm_c3x as no compatible archs found "
"in CUDA target architectures.")
endif()
endif()
# Expert-specialization MXFP8 blockscaled grouped kernels (SM100+).
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(ES_MXFP8_GROUPED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
@@ -557,6 +787,36 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"in CUDA target architectures.")
endif()
# moe_data.cu is used by all CUTLASS MoE kernels.
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)
set(SRCS "csrc/quantization/w8a8/cutlass/moe/moe_data.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${CUTLASS_MOE_DATA_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
message(STATUS "Building moe_data for archs: ${CUTLASS_MOE_DATA_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)
message(STATUS "Not building moe_data as CUDA Compiler version is "
"not >= 12.3, we recommend upgrading to CUDA 12.3 or later "
"if you intend on running FP8 quantized MoE models on Hopper or Blackwell.")
else()
message(STATUS "Not building moe_data as no compatible archs found "
"in CUDA target architectures.")
endif()
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
endif()
#
# Machete kernels
@@ -627,6 +887,34 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()
endif()
# Only build W4A8 kernels if we are building for something compatible with sm90a
cuda_archs_loose_intersection(W4A8_ARCHS "9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND W4A8_ARCHS)
set(SRCS
"csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu"
"csrc/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu"
"csrc/quantization/cutlass_w4a8/w4a8_utils.cu"
)
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${W4A8_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
message(STATUS "Building W4A8 kernels for archs: ${W4A8_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0
AND W4A8_ARCHS)
message(STATUS "Not building W4A8 kernels as CUDA Compiler version is "
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
"later if you intend on running w4a16 quantized models on "
"Hopper.")
else()
message(STATUS "Not building W4A8 kernels as no compatible archs "
"found in CUDA target architectures")
endif()
endif()
# Hadacore kernels
cuda_archs_loose_intersection(HADACORE_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}")
@@ -676,12 +964,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# _C_stable_libtorch extension (ops registered via STABLE_TORCH_LIBRARY)
#
set(VLLM_STABLE_EXT_SRC
"csrc/libtorch_stable/torch_bindings.cpp"
"csrc/cutlass_extensions/common.cpp"
"csrc/cuda_utils_kernels.cu"
"csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu")
"csrc/libtorch_stable/torch_bindings.cpp")
if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_STABLE_EXT_SRC
@@ -696,299 +979,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
CUDA_ARCHS "${CUDA_ARCHS}")
endif()
#
# CUTLASS scaled_mm kernels (moved from _C to _C_stable_libtorch)
#
set(SCALED_MM_3X_ARCHS)
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
# CUDA 12.0 or later
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS)
set(SRCS
"csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c3x_sm90.cu"
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8.cu"
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_int8.cu"
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_azp_sm90_int8.cu"
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM90=1")
# Let scaled_mm_c2x know it doesn't need to build these arches
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
message(STATUS "Building scaled_mm_c3x_sm90 for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS)
message(STATUS "Not building scaled_mm_c3x_sm90 as CUDA Compiler version is "
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
"later if you intend on running FP8 quantized models on "
"Hopper.")
else()
message(STATUS "Not building scaled_mm_c3x_sm90 as no compatible archs found "
"in CUDA target architectures")
endif()
endif()
# The cutlass_scaled_mm kernels for Blackwell SM12x (c3x, i.e. CUTLASS 3.x) require
# CUDA 12.8 or later
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0a;12.1a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
set(SRCS
"csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c3x_sm120.cu"
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8.cu"
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8.cu"
)
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM120=1")
# Let scaled_mm_c2x know it doesn't need to build these arches
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
message(STATUS "Building scaled_mm_c3x_sm120 for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
message(STATUS "Not building scaled_mm_c3x_sm120 as CUDA Compiler version is "
"not >= 12.8, we recommend upgrading to CUDA 12.8 or "
"later if you intend on running FP8 quantized models on "
"Blackwell.")
else()
message(STATUS "Not building scaled_mm_c3x_120 as no compatible archs found "
"in CUDA target architectures")
endif()
endif()
# The cutlass_scaled_mm kernels for Blackwell SM100 (c3x, i.e. CUTLASS 3.x)
# require CUDA 12.8 or later
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
set(SRCS
"csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c3x_sm100.cu"
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu"
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8.cu"
)
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM100=1")
# Let scaled_mm_c2x know it doesn't need to build these arches
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
message(STATUS "Building scaled_mm_c3x_sm100 for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
message(STATUS "Not building scaled_mm_c3x_sm100 as CUDA Compiler version is "
"not >= 12.8, we recommend upgrading to CUDA 12.8 or "
"later if you intend on running FP8 quantized models on "
"Blackwell.")
else()
message(STATUS "Not building scaled_mm_c3x_100 as no compatible archs found "
"in CUDA target architectures")
endif()
endif()
#
# For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x)
# kernels for the remaining archs that are not already built for 3x.
# (Build 8.9 for FP8)
cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS
"7.5;8.0;8.7;8.9+PTX" "${CUDA_ARCHS}")
# subtract out the archs that are already built for 3x
list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS})
if (SCALED_MM_2X_ARCHS)
set(SRCS "csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c2x.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_2X_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_C2X=1")
message(STATUS "Building scaled_mm_c2x for archs: ${SCALED_MM_2X_ARCHS}")
else()
if (SCALED_MM_3X_ARCHS)
message(STATUS "Not building scaled_mm_c2x as all archs are already built"
" for and covered by scaled_mm_c3x")
else()
message(STATUS "Not building scaled_mm_c2x as no compatible archs found "
"in CUDA target architectures")
endif()
endif()
#
# CUTLASS MoE kernels (moved from _C to _C_stable_libtorch)
#
# The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and ONLY works
# on Hopper). get_cutlass_(batched_)moe_mm_data should only be compiled
# if it's possible to compile MoE kernels that use its output.
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
set(SRCS "csrc/libtorch_stable/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm90.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM90=1")
message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is "
"not >= 12.3, we recommend upgrading to CUDA 12.3 or later "
"if you intend on running FP8 quantized MoE models on Hopper.")
else()
message(STATUS "Not building grouped_mm_c3x as no compatible archs found "
"in CUDA target architectures.")
endif()
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
set(SRCS "csrc/libtorch_stable/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm100.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is "
"not >= 12.8, we recommend upgrading to CUDA 12.8 or later "
"if you intend on running FP8 quantized MoE models on Blackwell.")
else()
message(STATUS "Not building grouped_mm_c3x as no compatible archs found "
"in CUDA target architectures.")
endif()
endif()
# moe_data.cu is used by all CUTLASS MoE kernels.
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)
set(SRCS "csrc/libtorch_stable/quantization/w8a8/cutlass/moe/moe_data.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${CUTLASS_MOE_DATA_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
message(STATUS "Building moe_data for archs: ${CUTLASS_MOE_DATA_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)
message(STATUS "Not building moe_data as CUDA Compiler version is "
"not >= 12.3, we recommend upgrading to CUDA 12.3 or later "
"if you intend on running FP8 quantized MoE models on Hopper or Blackwell.")
else()
message(STATUS "Not building moe_data as no compatible archs found "
"in CUDA target architectures.")
endif()
endif()
#
# FP4/NVFP4 kernels (moved from _C to _C_stable_libtorch)
#
# The nvfp4_scaled_mm_sm120 kernels for Blackwell SM12x require
# CUDA 12.8 or later
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(FP4_ARCHS "12.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(FP4_ARCHS "12.0a;12.1a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
set(SRCS
"csrc/libtorch_stable/quantization/fp4/nvfp4_quant_kernels.cu"
"csrc/libtorch_stable/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_experts_quant.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_blockwise_moe_kernel.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${FP4_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM120=1")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM120=1")
message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
else()
message(STATUS "Not building NVFP4 as no compatible archs were found.")
# clear FP4_ARCHS
set(FP4_ARCHS)
endif()
# FP4 Archs and flags
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(FP4_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(FP4_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
set(SRCS
"csrc/libtorch_stable/quantization/fp4/nvfp4_quant_kernels.cu"
"csrc/libtorch_stable/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_experts_quant.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_kernels.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_blockwise_moe_kernel.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${FP4_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM100=1")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
else()
message(STATUS "Not building NVFP4 as no compatible archs were found.")
# clear FP4_ARCHS
set(FP4_ARCHS)
endif()
#
# W4A8 kernels (moved from _C to _C_stable_libtorch)
#
# Only build W4A8 kernels if we are building for something compatible with sm90a
cuda_archs_loose_intersection(W4A8_ARCHS "9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND W4A8_ARCHS)
set(SRCS
"csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_mm_entry.cu"
"csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu"
"csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_utils.cu"
)
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${W4A8_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
message(STATUS "Building W4A8 kernels for archs: ${W4A8_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0
AND W4A8_ARCHS)
message(STATUS "Not building W4A8 kernels as CUDA Compiler version is "
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
"later if you intend on running w4a16 quantized models on "
"Hopper.")
else()
message(STATUS "Not building W4A8 kernels as no compatible archs "
"found in CUDA target architectures")
endif()
endif()
message(STATUS "Enabling C_stable extension.")
define_extension_target(
_C_stable_libtorch
@@ -997,7 +987,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
SOURCES ${VLLM_STABLE_EXT_SRC}
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
ARCHITECTURES ${VLLM_GPU_ARCHES}
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
USE_SABI 3
WITH_SOABI)
@@ -1011,10 +1000,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# Needed to use cuda APIs from C-shim
target_compile_definitions(_C_stable_libtorch PRIVATE
USE_CUDA)
# Needed by CUTLASS kernels
target_compile_definitions(_C_stable_libtorch PRIVATE
CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
endif()
#
@@ -1030,6 +1015,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_MOE_EXT_SRC
"csrc/moe/moe_wna16.cu"
"csrc/moe/grouped_topk_kernels.cu"
"csrc/moe/gpt_oss_router_gemm.cu"
"csrc/moe/router_gemm.cu")
endif()

View File

@@ -1,264 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Benchmark: Fused FP8 output quantization in merge_attn_states
Compares fused vs unfused approaches for producing FP8-quantized merged
attention output:
1. Fused CUDA -- single CUDA kernel (merge + FP8 quant)
2. Fused Triton -- single Triton kernel (merge + FP8 quant)
3. Unfused CUDA -- CUDA merge + torch.compiled FP8 quant
4. Unfused Triton -- Triton merge + torch.compiled FP8 quant
Usage:
python benchmarks/fused_kernels/merge_attn_states_benchmarks.py
python benchmarks/fused_kernels/merge_attn_states_benchmarks.py --tp 1 4 8
python benchmarks/fused_kernels/merge_attn_states_benchmarks.py --dtype bfloat16
"""
import argparse
import itertools
import torch
from vllm._custom_ops import merge_attn_states as merge_attn_states_cuda
from vllm.benchmarks.lib.utils import default_vllm_config
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform
from vllm.triton_utils import triton
from vllm.v1.attention.ops.triton_merge_attn_states import (
merge_attn_states as merge_attn_states_triton,
)
# ---------------------------------------------------------------------------
# Configuration defaults
# ---------------------------------------------------------------------------
NUM_TOKENS_LIST = [1, 16, 64, 256, 1024, 4096]
# (label, num_heads, head_size) — num_heads is for TP=1
HEAD_CONFIGS = [
("DeepSeek-V3 MLA", 128, 128),
("Llama-70B", 64, 128),
("Llama-8B", 32, 128),
]
TP_SIZES = [1, 2, 4, 8]
INPUT_DTYPES = [torch.float32, torch.float16, torch.bfloat16]
QUANTILES = [0.5, 0.2, 0.8]
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def short_dtype(dtype: torch.dtype) -> str:
return str(dtype).removeprefix("torch.")
def make_inputs(
num_tokens: int,
num_heads: int,
head_size: int,
dtype: torch.dtype,
):
"""Create random prefix/suffix outputs and LSEs."""
prefix_output = torch.randn(
(num_tokens, num_heads, head_size), dtype=dtype, device="cuda"
)
suffix_output = torch.randn(
(num_tokens, num_heads, head_size), dtype=dtype, device="cuda"
)
prefix_lse = torch.randn(num_heads, num_tokens, dtype=torch.float32, device="cuda")
suffix_lse = torch.randn(num_heads, num_tokens, dtype=torch.float32, device="cuda")
# Sprinkle some inf values to exercise edge-case paths
mask = torch.rand(num_heads, num_tokens, device="cuda") < 0.05
prefix_lse[mask] = float("inf")
mask2 = torch.rand(num_heads, num_tokens, device="cuda") < 0.05
suffix_lse[mask2] = float("inf")
return prefix_output, suffix_output, prefix_lse, suffix_lse
def build_configs(head_configs, num_tokens_list, input_dtypes, tp_sizes):
"""Build (num_tokens, num_heads, head_size, dtype_str) config tuples,
applying TP division to num_heads and skipping invalid combos."""
configs = []
for (_, nh, hs), nt, dtype, tp in itertools.product(
head_configs, num_tokens_list, input_dtypes, tp_sizes
):
nh_tp = nh // tp
if nh_tp >= 1:
configs.append((nt, nh_tp, hs, short_dtype(dtype)))
return configs
def parse_args():
parser = argparse.ArgumentParser(
description="Benchmark merge_attn_states fused FP8 quantization"
)
parser.add_argument(
"--num-tokens",
type=int,
nargs="+",
default=None,
help=f"Override token counts (default: {NUM_TOKENS_LIST})",
)
parser.add_argument(
"--tp",
type=int,
nargs="+",
default=None,
help=f"TP sizes to simulate (divides num_heads) (default: {TP_SIZES})",
)
parser.add_argument(
"--dtype",
type=str,
nargs="+",
default=None,
help="Input dtypes (e.g. bfloat16 float16 float32). "
f"Default: {[short_dtype(d) for d in INPUT_DTYPES]}",
)
return parser.parse_args()
# ---------------------------------------------------------------------------
# Parse args and build configs before decorators
# ---------------------------------------------------------------------------
args = parse_args()
num_tokens_list = args.num_tokens if args.num_tokens else NUM_TOKENS_LIST
tp_sizes = args.tp if args.tp else TP_SIZES
if args.dtype:
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
input_dtypes = [STR_DTYPE_TO_TORCH_DTYPE[d] for d in args.dtype]
else:
input_dtypes = INPUT_DTYPES
configs = build_configs(HEAD_CONFIGS, num_tokens_list, input_dtypes, tp_sizes)
torch._dynamo.config.recompile_limit = 8888
# ---------------------------------------------------------------------------
# Benchmark function
# ---------------------------------------------------------------------------
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["num_tokens", "num_heads", "head_size", "dtype_str"],
x_vals=configs,
line_arg="provider",
line_vals=["fused_cuda", "fused_triton", "unfused_cuda", "unfused_triton"],
line_names=["Fused CUDA", "Fused Triton", "Unfused CUDA", "Unfused Triton"],
styles=[("blue", "-"), ("green", "-"), ("blue", "--"), ("green", "--")],
ylabel="us",
plot_name="merge_attn_states FP8 (fused vs unfused)",
args={},
)
)
@default_vllm_config()
def benchmark(num_tokens, num_heads, head_size, dtype_str, provider):
input_dtype = getattr(torch, dtype_str)
fp8_dtype = current_platform.fp8_dtype()
prefix_out, suffix_out, prefix_lse, suffix_lse = make_inputs(
num_tokens, num_heads, head_size, input_dtype
)
output_scale = torch.tensor([0.1], dtype=torch.float32, device="cuda")
if provider == "fused_cuda":
output = torch.empty(
(num_tokens, num_heads, head_size), dtype=fp8_dtype, device="cuda"
)
fn = lambda: merge_attn_states_cuda(
output,
prefix_out,
prefix_lse,
suffix_out,
suffix_lse,
output_scale=output_scale,
)
elif provider == "fused_triton":
output = torch.empty(
(num_tokens, num_heads, head_size), dtype=fp8_dtype, device="cuda"
)
fn = lambda: merge_attn_states_triton(
output,
prefix_out,
prefix_lse,
suffix_out,
suffix_lse,
output_scale=output_scale,
)
elif provider == "unfused_cuda":
merge_buf = torch.empty(
(num_tokens, num_heads, head_size), dtype=input_dtype, device="cuda"
)
quant_fp8 = QuantFP8(
static=True,
group_shape=GroupShape.PER_TENSOR,
column_major_scales=False,
)
quant_input = merge_buf.view(-1, head_size)
compiled_quant = torch.compile(
quant_fp8.forward_native, fullgraph=True, dynamic=False
)
def unfused_fn():
merge_attn_states_cuda(
merge_buf, prefix_out, prefix_lse, suffix_out, suffix_lse
)
compiled_quant(quant_input, output_scale)
fn = unfused_fn
else: # unfused_triton
merge_buf = torch.empty(
(num_tokens, num_heads, head_size), dtype=input_dtype, device="cuda"
)
quant_fp8 = QuantFP8(
static=True,
group_shape=GroupShape.PER_TENSOR,
column_major_scales=False,
)
quant_input = merge_buf.view(-1, head_size)
compiled_quant = torch.compile(
quant_fp8.forward_native, fullgraph=True, dynamic=False
)
def unfused_fn():
merge_attn_states_triton(
merge_buf, prefix_out, prefix_lse, suffix_out, suffix_lse
)
compiled_quant(quant_input, output_scale)
fn = unfused_fn
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=QUANTILES)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms # us
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
device_name = current_platform.get_device_name()
print(f"Device: {device_name}")
print(f"Token counts: {num_tokens_list}")
print(f"TP sizes: {tp_sizes}")
print(f"Input dtypes: {[short_dtype(d) for d in input_dtypes]}")
print(f"Head configs: {[(c[0], c[1], c[2]) for c in HEAD_CONFIGS]}")
benchmark.run(print_data=True)
if __name__ == "__main__":
with torch.inference_mode():
main()

View File

@@ -1,211 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable, Iterable
from dataclasses import dataclass
from itertools import product
import torch
import torch.nn.functional as F
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from tqdm import tqdm
import vllm._custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8,
)
@dataclass
class bench_params_t:
num_tokens: int
hidden_size: int
dtype: torch.dtype
group_size: int # Changed from list[int] to int
def description(self):
return (
f"N {self.num_tokens} "
f"x D {self.hidden_size} "
f"x DT {self.dtype} "
f"x GS {self.group_size}"
)
def get_bench_params() -> list[bench_params_t]:
"""Test configurations covering common model sizes."""
NUM_TOKENS = [16, 128, 512, 2048]
HIDDEN_SIZES = [1024, 2048, 4096, 5120, 14336] # Common FFN sizes
DTYPES = [torch.float16, torch.bfloat16]
GROUP_SIZES = [64, 128] # Changed from [[1, 64], [1, 128]]
combinations = product(NUM_TOKENS, HIDDEN_SIZES, DTYPES, GROUP_SIZES)
bench_params = list(
map(lambda x: bench_params_t(x[0], x[1], x[2], x[3]), combinations)
)
return bench_params
# Reference implementations
def unfused_fp8_impl(
x: torch.Tensor,
quant_dtype: torch.dtype,
group_size: int, # Changed from list[int]
):
"""Unfused: SiLU+Mul then per-tensor quantize."""
hidden = x.shape[-1] // 2
gate, up = x.split(hidden, dim=-1)
# SiLU(gate) * up
silu_out = F.silu(gate) * up
# Per-tensor quantize (no group_size used here)
silu_out, _ = ops.scaled_fp8_quant(silu_out)
def unfused_groupwise_fp8_impl(
x: torch.Tensor,
quant_dtype: torch.dtype,
group_size: int, # Changed from list[int]
):
"""Unfused: SiLU+Mul then group-wise quantize."""
hidden = x.shape[-1] // 2
gate, up = x.split(hidden, dim=-1)
# SiLU(gate) * up
silu_out = F.silu(gate) * up
# Group quantize - use group_size directly
silu_out, _ = per_token_group_quant_fp8(
silu_out, group_size=group_size, use_ue8m0=False
)
def fused_impl(
x: torch.Tensor,
quant_dtype: torch.dtype,
group_size: int,
):
"""Fused: SiLU+Mul+Block Quantization in single kernel."""
out, _ = ops.silu_and_mul_per_block_quant(
x,
group_size=group_size,
quant_dtype=quant_dtype,
is_scale_transposed=False,
)
# Bench functions
def bench_fn(
x: torch.Tensor,
quant_dtype: torch.dtype,
group_size: int,
label: str,
sub_label: str,
fn: Callable,
description: str,
) -> TMeasurement:
min_run_time = 1
globals = {
"x": x,
"quant_dtype": quant_dtype,
"group_size": group_size,
"fn": fn,
}
return TBenchmark.Timer(
stmt="fn(x, quant_dtype, group_size)",
globals=globals,
label=label,
sub_label=sub_label,
description=description,
).blocked_autorange(min_run_time=min_run_time)
def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasurement]:
"""Run benchmarks for all implementations."""
# Make inputs: [num_tokens, hidden_size * 2] for [gate || up]
scale = 1 / params.hidden_size
x = (
torch.randn(
params.num_tokens,
params.hidden_size * 2,
dtype=params.dtype,
device="cuda",
)
* scale
)
timers = []
# Unfused per-tensor FP8
timers.append(
bench_fn(
x,
torch.float8_e4m3fn,
params.group_size,
label,
sub_label,
unfused_fp8_impl,
"unfused_fp8_impl",
)
)
# Unfused group-wise FP8
timers.append(
bench_fn(
x,
torch.float8_e4m3fn,
params.group_size,
label,
sub_label,
unfused_groupwise_fp8_impl,
"unfused_groupwise_fp8_impl",
)
)
# Fused group-wise FP8
timers.append(
bench_fn(
x,
torch.float8_e4m3fn,
params.group_size,
label,
sub_label,
fused_impl,
"fused_groupwise_fp8_impl",
)
)
return timers
def print_timers(timers: Iterable[TMeasurement]):
compare = TBenchmark.Compare(timers)
compare.print()
def main():
torch.set_default_device("cuda")
bench_params = get_bench_params()
print(f"Running {len(bench_params)} benchmark configurations...")
print(
f"This will take approximately {len(bench_params) * 3} seconds (1s per variant)"
)
print()
timers = []
for bp in tqdm(bench_params):
result_timers = bench(bp, "silu-mul-block-quant", bp.description())
timers.extend(result_timers)
print("\n" + "=" * 80)
print("FINAL COMPARISON - ALL RESULTS")
print("=" * 80)
print_timers(timers)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,134 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.nn.functional as F
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_config
from vllm.triton_utils import triton
from vllm.utils.argparse_utils import FlexibleArgumentParser
# Dimensions supported by the DSV3 specialized kernel
DSV3_SUPPORTED_NUM_EXPERTS = [256, 384]
DSV3_SUPPORTED_HIDDEN_SIZES = [7168]
# Dimensions supported by the gpt-oss specialized kernel
GPT_OSS_SUPPORTED_NUM_EXPERTS = [32, 128]
GPT_OSS_SUPPORTED_HIDDEN_SIZES = [2880]
def get_batch_size_range(max_batch_size):
return [2**x for x in range(14) if 2**x <= max_batch_size]
def get_model_params(config):
if config.architectures[0] in (
"DeepseekV2ForCausalLM",
"DeepseekV3ForCausalLM",
"DeepseekV32ForCausalLM",
):
num_experts = config.n_routed_experts
hidden_size = config.hidden_size
elif config.architectures[0] in ("GptOssForCausalLM",):
num_experts = config.num_local_experts
hidden_size = config.hidden_size
else:
raise ValueError(f"Unsupported architecture: {config.architectures}")
return num_experts, hidden_size
def get_benchmark(model, max_batch_size, trust_remote_code):
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=get_batch_size_range(max_batch_size),
x_log=False,
line_arg="provider",
line_vals=[
"torch",
"vllm",
],
line_names=["PyTorch", "vLLM"],
styles=([("blue", "-"), ("red", "-")]),
ylabel="TFLOPs",
plot_name=f"{model} router gemm throughput",
args={},
)
)
def benchmark(batch_size, provider):
config = get_config(model=model, trust_remote_code=trust_remote_code)
num_experts, hidden_size = get_model_params(config)
mat_a = torch.randn(
(batch_size, hidden_size), dtype=torch.bfloat16, device="cuda"
).contiguous()
mat_b = torch.randn(
(num_experts, hidden_size), dtype=torch.bfloat16, device="cuda"
).contiguous()
bias = torch.randn(
num_experts, dtype=torch.bfloat16, device="cuda"
).contiguous()
is_hopper_or_blackwell = current_platform.is_device_capability(
90
) or current_platform.is_device_capability_family(100)
allow_dsv3_router_gemm = (
is_hopper_or_blackwell
and num_experts in DSV3_SUPPORTED_NUM_EXPERTS
and hidden_size in DSV3_SUPPORTED_HIDDEN_SIZES
)
allow_gpt_oss_router_gemm = (
is_hopper_or_blackwell
and num_experts in GPT_OSS_SUPPORTED_NUM_EXPERTS
and hidden_size in GPT_OSS_SUPPORTED_HIDDEN_SIZES
)
has_bias = False
if allow_gpt_oss_router_gemm:
has_bias = True
quantiles = [0.5, 0.2, 0.8]
if provider == "torch":
def runner():
if has_bias:
F.linear(mat_a, mat_b, bias)
else:
F.linear(mat_a, mat_b)
elif provider == "vllm":
def runner():
if allow_dsv3_router_gemm:
ops.dsv3_router_gemm(mat_a, mat_b, torch.bfloat16)
elif allow_gpt_oss_router_gemm:
ops.gpt_oss_router_gemm(mat_a, mat_b, bias)
else:
raise ValueError("Unsupported router gemm")
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
runner, quantiles=quantiles
)
def tflops(t_ms):
flops = 2 * batch_size * hidden_size * num_experts
return flops / (t_ms * 1e-3) / 1e12
return tflops(ms), tflops(max_ms), tflops(min_ms)
return benchmark
if __name__ == "__main__":
parser = FlexibleArgumentParser()
parser.add_argument("--model", type=str, default="openai/gpt-oss-20b")
parser.add_argument("--max-batch-size", default=16, type=int)
parser.add_argument("--trust-remote-code", action="store_true")
args = parser.parse_args()
# Get the benchmark function
benchmark = get_benchmark(args.model, args.max_batch_size, args.trust_remote_code)
# Run performance benchmark
benchmark.run(print_data=True)

View File

@@ -1,162 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Benchmarks the fused Triton bilinear position-embedding kernel against
# the pure-PyTorch (native) implementation used in Qwen3-VL ViT models.
#
# == Usage Examples ==
#
# Default benchmark:
# python3 benchmark_vit_bilinear_pos_embed.py
#
# Custom parameters:
# python3 benchmark_vit_bilinear_pos_embed.py --hidden-dim 1152 \
# --num-grid-per-side 48 --save-path ./configs/vit_pos_embed/
import itertools
import torch
from vllm.model_executor.models.qwen3_vl import (
pos_embed_interpolate_native,
triton_pos_embed_interpolate,
)
from vllm.triton_utils import HAS_TRITON, triton
from vllm.utils.argparse_utils import FlexibleArgumentParser
# (h, w) configurations to benchmark
h_w_configs = [
(16, 16),
(32, 32),
(48, 48),
(64, 64),
(128, 128),
(32, 48),
(60, 80),
]
# Temporal dimensions
t_range = [1]
configs = list(itertools.product(t_range, h_w_configs))
def get_benchmark(
num_grid_per_side: int,
spatial_merge_size: int,
hidden_dim: int,
dtype: torch.dtype,
device: str,
):
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["t", "h_w"],
x_vals=[list(_) for _ in configs],
line_arg="provider",
line_vals=["native", "triton"],
line_names=["Native (PyTorch)", "Triton"],
styles=[("blue", "-"), ("red", "-")],
ylabel="us",
plot_name=(
f"vit-bilinear-pos-embed-"
f"grid{num_grid_per_side}-"
f"dim{hidden_dim}-"
f"{dtype}"
),
args={},
)
)
def benchmark(t, h_w, provider):
h, w = h_w
torch.manual_seed(42)
embed_weight = (
torch.randn(
num_grid_per_side * num_grid_per_side,
hidden_dim,
device=device,
dtype=dtype,
)
* 0.25
)
quantiles = [0.5, 0.2, 0.8]
if provider == "native":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: pos_embed_interpolate_native(
embed_weight,
t,
h,
w,
num_grid_per_side,
spatial_merge_size,
dtype,
),
quantiles=quantiles,
)
else:
assert HAS_TRITON, "Triton not available"
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: triton_pos_embed_interpolate(
embed_weight,
t,
h,
w,
num_grid_per_side,
spatial_merge_size,
dtype,
),
quantiles=quantiles,
)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
return benchmark
if __name__ == "__main__":
parser = FlexibleArgumentParser(
description="Benchmark bilinear position embedding interpolation."
)
parser.add_argument(
"--num-grid-per-side",
type=int,
default=48,
help="Position embedding grid size (default: 48 for Qwen3-VL)",
)
parser.add_argument(
"--spatial-merge-size",
type=int,
default=2,
help="Spatial merge size (default: 2)",
)
parser.add_argument(
"--hidden-dim",
type=int,
default=1152,
help="Embedding hidden dimension (default: 1152 for Qwen3-VL)",
)
parser.add_argument(
"--device",
type=str,
choices=["cuda:0", "cuda:1"],
default="cuda:0",
)
parser.add_argument(
"--save-path",
type=str,
default="./vit_pos_embed/",
)
args = parser.parse_args()
dtype = torch.bfloat16
bench = get_benchmark(
args.num_grid_per_side,
args.spatial_merge_size,
args.hidden_dim,
dtype,
args.device,
)
bench.run(print_data=True, save_path=args.save_path)

View File

@@ -373,7 +373,6 @@ if (ENABLE_X86_ISA)
"csrc/cpu/sgl-kernels/gemm.cpp"
"csrc/cpu/sgl-kernels/gemm_int8.cpp"
"csrc/cpu/sgl-kernels/gemm_fp8.cpp"
"csrc/cpu/sgl-kernels/gemm_int4.cpp"
"csrc/cpu/sgl-kernels/moe.cpp"
"csrc/cpu/sgl-kernels/moe_int8.cpp"
"csrc/cpu/sgl-kernels/moe_fp8.cpp")

View File

@@ -39,7 +39,7 @@ else()
FetchContent_Declare(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG c0ec424fd8a546d0cbbf4bf050bbcfe837c55afb
GIT_TAG 29210221863736a08f71a866459e368ad1ac4a95
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn

View File

@@ -3,33 +3,22 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <algorithm>
#include <limits>
#include "attention_dtypes.h"
#include "attention_utils.cuh"
#include "../quantization/w8a8/fp8/common.cuh"
#include "../dispatch_utils.h"
namespace vllm {
// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
// can be used to combine partial attention results (in the split-KV case)
template <typename scalar_t, typename output_t, const uint NUM_THREADS,
bool USE_FP8_OUTPUT>
template <typename scalar_t, const uint NUM_THREADS>
__global__ void merge_attn_states_kernel(
output_t* output, float* output_lse, const scalar_t* prefix_output,
scalar_t* output, float* output_lse, const scalar_t* prefix_output,
const float* prefix_lse, const scalar_t* suffix_output,
const float* suffix_lse, const uint num_tokens, const uint num_heads,
const uint head_size, const uint prefix_head_stride,
const uint output_head_stride, const uint prefix_num_tokens,
const float* output_scale) {
// Inputs always load 128-bit packs (pack_size elements of scalar_t).
// Outputs store pack_size elements of output_t, which is smaller for FP8.
using input_pack_t = uint4;
using output_pack_t =
std::conditional_t<USE_FP8_OUTPUT,
std::conditional_t<sizeof(scalar_t) == 4, uint, uint2>,
uint4>;
const uint output_head_stride) {
using pack_128b_t = uint4;
const uint pack_size = 16 / sizeof(scalar_t);
const uint threads_per_head = head_size / pack_size;
@@ -52,45 +41,8 @@ __global__ void merge_attn_states_kernel(
head_idx * output_head_stride;
const scalar_t* prefix_head_ptr = prefix_output + src_head_offset;
const scalar_t* suffix_head_ptr = suffix_output + src_head_offset;
output_t* output_head_ptr = output + dst_head_offset;
scalar_t* output_head_ptr = output + dst_head_offset;
// Pre-invert scale: multiplication is faster than division
float fp8_scale_inv = 1.0f;
if constexpr (USE_FP8_OUTPUT) {
fp8_scale_inv = 1.0f / *output_scale;
}
// If token_idx >= prefix_num_tokens, just copy from suffix
if (token_idx >= prefix_num_tokens) {
if (pack_offset < head_size) {
input_pack_t s_out_pack = reinterpret_cast<const input_pack_t*>(
suffix_head_ptr)[pack_offset / pack_size];
if constexpr (USE_FP8_OUTPUT) {
output_t o_out_pack[pack_size];
#pragma unroll
for (uint i = 0; i < pack_size; ++i) {
const float val =
vllm::to_float(reinterpret_cast<const scalar_t*>(&s_out_pack)[i]);
o_out_pack[i] =
vllm::scaled_fp8_conversion<true, output_t>(val, fp8_scale_inv);
}
reinterpret_cast<output_pack_t*>(
output_head_ptr)[pack_offset / pack_size] =
*reinterpret_cast<output_pack_t*>(o_out_pack);
} else {
reinterpret_cast<output_pack_t*>(
output_head_ptr)[pack_offset / pack_size] = s_out_pack;
}
}
if (output_lse != nullptr && pack_idx == 0) {
float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
output_lse[head_idx * num_tokens + token_idx] = s_lse;
}
return;
}
// For tokens within prefix range, merge prefix and suffix
float p_lse = prefix_lse[head_idx * num_tokens + token_idx];
float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
p_lse = std::isinf(p_lse) ? -std::numeric_limits<float>::infinity() : p_lse;
@@ -101,34 +53,20 @@ __global__ void merge_attn_states_kernel(
/* In certain edge cases, MLA can produce p_lse = s_lse = -inf;
continuing the pipeline then yields NaN. Root cause: with chunked prefill
a batch may be split into two chunks; if a request in that batch has no
prefix hit, every LSE entry for that request's position is -inf, and at
prefix hit, every LSE entry for that requests position is -inf, and at
this moment we merge cross-attention at first. For now we simply emit
prefix_output (expected to be all zeros) and prefix_lse (-inf) to fix
this problem.
*/
if (std::isinf(max_lse)) {
if (pack_offset < head_size) {
input_pack_t p_out_pack = reinterpret_cast<const input_pack_t*>(
// Pack 128b load
pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(
prefix_head_ptr)[pack_offset / pack_size];
if constexpr (USE_FP8_OUTPUT) {
// Convert prefix values to FP8 (since -inf means no data,
// prefix_output is expected to be zeros)
output_t o_out_pack[pack_size];
#pragma unroll
for (uint i = 0; i < pack_size; ++i) {
const float val =
vllm::to_float(reinterpret_cast<const scalar_t*>(&p_out_pack)[i]);
o_out_pack[i] =
vllm::scaled_fp8_conversion<true, output_t>(val, fp8_scale_inv);
}
reinterpret_cast<output_pack_t*>(
output_head_ptr)[pack_offset / pack_size] =
*reinterpret_cast<output_pack_t*>(o_out_pack);
} else {
reinterpret_cast<output_pack_t*>(
output_head_ptr)[pack_offset / pack_size] = p_out_pack;
}
// Pack 128b storage
reinterpret_cast<pack_128b_t*>(output_head_ptr)[pack_offset / pack_size] =
p_out_pack;
}
// We only need to write to output_lse once per head.
if (output_lse != nullptr && pack_idx == 0) {
@@ -146,43 +84,30 @@ __global__ void merge_attn_states_kernel(
const float s_scale = s_se / out_se;
if (pack_offset < head_size) {
input_pack_t p_out_pack = reinterpret_cast<const input_pack_t*>(
// Pack 128b load
pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(
prefix_head_ptr)[pack_offset / pack_size];
input_pack_t s_out_pack = reinterpret_cast<const input_pack_t*>(
pack_128b_t s_out_pack = reinterpret_cast<const pack_128b_t*>(
suffix_head_ptr)[pack_offset / pack_size];
pack_128b_t o_out_pack;
// Compute merged values in float32
float o_out_f[pack_size];
#pragma unroll
for (uint i = 0; i < pack_size; ++i) {
// Always use float for FMA to keep high precision.
// half(uint16_t), bfloat16, float -> float.
const float p_out_f =
vllm::to_float(reinterpret_cast<const scalar_t*>(&p_out_pack)[i]);
const float s_out_f =
vllm::to_float(reinterpret_cast<const scalar_t*>(&s_out_pack)[i]);
o_out_f[i] = p_out_f * p_scale + (s_out_f * s_scale);
// fma: a * b + c = p_out_f * p_scale + (s_out_f * s_scale)
const float o_out_f = p_out_f * p_scale + (s_out_f * s_scale);
// float -> half(uint16_t), bfloat16, float.
vllm::from_float(reinterpret_cast<scalar_t*>(&o_out_pack)[i], o_out_f);
}
// Convert and store
if constexpr (USE_FP8_OUTPUT) {
output_t o_out_pack[pack_size];
#pragma unroll
for (uint i = 0; i < pack_size; ++i) {
o_out_pack[i] = vllm::scaled_fp8_conversion<true, output_t>(
o_out_f[i], fp8_scale_inv);
}
reinterpret_cast<output_pack_t*>(
output_head_ptr)[pack_offset / pack_size] =
*reinterpret_cast<output_pack_t*>(o_out_pack);
} else {
output_pack_t o_out_pack;
#pragma unroll
for (uint i = 0; i < pack_size; ++i) {
vllm::from_float(reinterpret_cast<scalar_t*>(&o_out_pack)[i],
o_out_f[i]);
}
reinterpret_cast<output_pack_t*>(
output_head_ptr)[pack_offset / pack_size] = o_out_pack;
}
// Pack 128b storage
reinterpret_cast<pack_128b_t*>(output_head_ptr)[pack_offset / pack_size] =
o_out_pack;
}
// We only need to write to output_lse once per head.
if (output_lse != nullptr && pack_idx == 0) {
@@ -209,73 +134,50 @@ __global__ void merge_attn_states_kernel(
} \
}
#define LAUNCH_MERGE_ATTN_STATES(scalar_t, output_t, NUM_THREADS, \
USE_FP8_OUTPUT) \
#define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS) \
{ \
vllm::merge_attn_states_kernel<scalar_t, output_t, NUM_THREADS, \
USE_FP8_OUTPUT> \
vllm::merge_attn_states_kernel<scalar_t, NUM_THREADS> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<output_t*>(output.data_ptr()), output_lse_ptr, \
reinterpret_cast<scalar_t*>(output.data_ptr()), output_lse_ptr, \
reinterpret_cast<scalar_t*>(prefix_output.data_ptr()), \
reinterpret_cast<float*>(prefix_lse.data_ptr()), \
reinterpret_cast<scalar_t*>(suffix_output.data_ptr()), \
reinterpret_cast<float*>(suffix_lse.data_ptr()), num_tokens, \
num_heads, head_size, prefix_head_stride, output_head_stride, \
prefix_num_tokens, output_scale_ptr); \
num_heads, head_size, prefix_head_stride, output_head_stride); \
}
/*@brief Merges the attention states from prefix and suffix
* into the output tensor. NUM_TOKENS: n, NUM_HEADS: h, HEAD_SIZE: d
*
* @param output [n,h,d] The output tensor to store the merged attention states.
* @param output_lse [h,n] Optional tensor to store the log-sum-exp values.
* @param output_lse [h,d] Optional tensor to store the log-sum-exp values.
* @param prefix_output [n,h,d] The prefix attention states.
* @param prefix_lse [h,n] The log-sum-exp values for the prefix attention
* states.
* @param suffix_output [n,h,d] The suffix attention states.
* @param suffix_lse [h,n] The log-sum-exp values for the suffix attention
* states.
* @param prefill_tokens_with_context Number of prefill tokens with context
* For the first p tokens (0 <= token_idx < prefill_tokens_with_context), output
* is computed by merging prefix_output and suffix_output. For remaining tokens
* (prefill_tokens_with_context <= token_idx < n), output is copied directly
* from suffix_output.
* @param output_scale Optional scalar tensor for FP8 static quantization.
* When provided, output must be FP8 dtype.
*/
template <typename scalar_t>
void merge_attn_states_launcher(
torch::Tensor& output, std::optional<torch::Tensor> output_lse,
const torch::Tensor& prefix_output, const torch::Tensor& prefix_lse,
const torch::Tensor& suffix_output, const torch::Tensor& suffix_lse,
const std::optional<int64_t> prefill_tokens_with_context,
const std::optional<torch::Tensor>& output_scale) {
void merge_attn_states_launcher(torch::Tensor& output,
std::optional<torch::Tensor> output_lse,
const torch::Tensor& prefix_output,
const torch::Tensor& prefix_lse,
const torch::Tensor& suffix_output,
const torch::Tensor& suffix_lse) {
constexpr uint NUM_THREADS = 128;
const uint num_tokens = output.size(0);
const uint num_heads = output.size(1);
const uint head_size = output.size(2);
const uint prefix_head_stride = prefix_output.stride(1);
const uint output_head_stride = output.stride(1);
// Thread mapping is based on input BF16 pack_size
const uint pack_size = 16 / sizeof(scalar_t);
TORCH_CHECK(head_size % pack_size == 0,
"headsize must be multiple of pack_size:", pack_size);
const uint prefix_num_tokens =
prefill_tokens_with_context.has_value()
? static_cast<uint>(prefill_tokens_with_context.value())
: num_tokens;
TORCH_CHECK(prefix_num_tokens <= num_tokens,
"prefix_num_tokens must be <= num_tokens");
float* output_lse_ptr = nullptr;
if (output_lse.has_value()) {
output_lse_ptr = output_lse.value().data_ptr<float>();
}
float* output_scale_ptr = nullptr;
if (output_scale.has_value()) {
output_scale_ptr = output_scale.value().data_ptr<float>();
}
// Process one pack elements per thread. for float, the
// pack_size is 4 for half/bf16, the pack_size is 8.
const uint threads_per_head = head_size / pack_size;
@@ -287,22 +189,14 @@ void merge_attn_states_launcher(
const c10::cuda::OptionalCUDAGuard device_guard(prefix_output.device());
auto stream = at::cuda::getCurrentCUDAStream();
if (output_scale.has_value()) {
// FP8 output path - dispatch on output FP8 type
VLLM_DISPATCH_FP8_TYPES(output.scalar_type(), "merge_attn_states_fp8", [&] {
LAUNCH_MERGE_ATTN_STATES(scalar_t, fp8_t, NUM_THREADS, true);
});
} else {
// Original BF16/FP16/FP32 output path
LAUNCH_MERGE_ATTN_STATES(scalar_t, scalar_t, NUM_THREADS, false);
}
LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS);
}
#define CALL_MERGE_ATTN_STATES_LAUNCHER(scalar_t) \
{ \
merge_attn_states_launcher<scalar_t>( \
output, output_lse, prefix_output, prefix_lse, suffix_output, \
suffix_lse, prefill_tokens_with_context, output_scale); \
#define CALL_MERGE_ATTN_STATES_LAUNCHER(scalar_t) \
{ \
merge_attn_states_launcher<scalar_t>(output, output_lse, prefix_output, \
prefix_lse, suffix_output, \
suffix_lse); \
}
void merge_attn_states(torch::Tensor& output,
@@ -310,21 +204,6 @@ void merge_attn_states(torch::Tensor& output,
const torch::Tensor& prefix_output,
const torch::Tensor& prefix_lse,
const torch::Tensor& suffix_output,
const torch::Tensor& suffix_lse,
std::optional<int64_t> prefill_tokens_with_context,
const std::optional<torch::Tensor>& output_scale) {
if (output_scale.has_value()) {
TORCH_CHECK(output.scalar_type() == at::ScalarType::Float8_e4m3fn ||
output.scalar_type() == at::ScalarType::Float8_e4m3fnuz,
"output must be FP8 when output_scale is provided, got: ",
output.scalar_type());
} else {
TORCH_CHECK(output.scalar_type() == prefix_output.scalar_type(),
"output dtype (", output.scalar_type(),
") must match prefix_output dtype (",
prefix_output.scalar_type(), ") when output_scale is not set");
}
// Always dispatch on prefix_output (input) dtype
DISPATCH_BY_SCALAR_DTYPE(prefix_output.dtype(),
CALL_MERGE_ATTN_STATES_LAUNCHER);
const torch::Tensor& suffix_lse) {
DISPATCH_BY_SCALAR_DTYPE(output.dtype(), CALL_MERGE_ATTN_STATES_LAUNCHER);
}

View File

@@ -10,10 +10,6 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
int64_t block_size_in_bytes,
const torch::Tensor& block_mapping);
void swap_blocks_batch(const torch::Tensor& src_ptrs,
const torch::Tensor& dst_ptrs,
const torch::Tensor& sizes);
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& slot_mapping,

View File

@@ -24,8 +24,6 @@
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
typedef __hip_bfloat16 __nv_bfloat16;
#else
#include <cuda.h>
#endif
#if defined(__gfx942__)
@@ -75,59 +73,6 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
}
}
void swap_blocks_batch(const torch::Tensor& src_ptrs,
const torch::Tensor& dst_ptrs,
const torch::Tensor& sizes) {
TORCH_CHECK(src_ptrs.device().is_cpu(), "src_ptrs must be on CPU");
TORCH_CHECK(dst_ptrs.device().is_cpu(), "dst_ptrs must be on CPU");
TORCH_CHECK(sizes.device().is_cpu(), "sizes must be on CPU");
TORCH_CHECK(src_ptrs.dtype() == torch::kInt64, "src_ptrs must be int64");
TORCH_CHECK(dst_ptrs.dtype() == torch::kInt64, "dst_ptrs must be int64");
TORCH_CHECK(sizes.dtype() == torch::kInt64, "sizes must be int64");
const int64_t n = src_ptrs.size(0);
TORCH_CHECK(dst_ptrs.size(0) == n, "dst_ptrs length must match src_ptrs");
TORCH_CHECK(sizes.size(0) == n, "sizes length must match src_ptrs");
if (n == 0) return;
const int64_t* src_data = src_ptrs.data_ptr<int64_t>();
const int64_t* dst_data = dst_ptrs.data_ptr<int64_t>();
const int64_t* size_data = sizes.data_ptr<int64_t>();
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// Use cuMemcpyBatchAsync (CUDA 12.8+) to submit all copies in a single
// driver call, amortizing per-copy submission overhead.
// int64_t and CUdeviceptr/size_t are both 8 bytes on 64-bit platforms,
// so we reinterpret_cast the tensor data directly to avoid copies.
static_assert(sizeof(CUdeviceptr) == sizeof(int64_t));
static_assert(sizeof(size_t) == sizeof(int64_t));
#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12080
CUmemcpyAttributes attr = {};
attr.srcAccessOrder = CU_MEMCPY_SRC_ACCESS_ORDER_STREAM;
size_t attrs_idx = 0;
size_t fail_idx = 0;
CUresult result = cuMemcpyBatchAsync(
reinterpret_cast<CUdeviceptr*>(const_cast<int64_t*>(dst_data)),
reinterpret_cast<CUdeviceptr*>(const_cast<int64_t*>(src_data)),
reinterpret_cast<size_t*>(const_cast<int64_t*>(size_data)),
static_cast<size_t>(n), &attr, &attrs_idx, 1, &fail_idx,
static_cast<CUstream>(stream));
TORCH_CHECK(result == CUDA_SUCCESS, "cuMemcpyBatchAsync failed at index ",
fail_idx, " with error ", result);
#else
// Fallback for CUDA < 12.8 and ROCm: individual async copies.
// cudaMemcpyDefault lets the driver infer direction from pointer types.
for (int64_t i = 0; i < n; i++) {
cudaMemcpyAsync(reinterpret_cast<void*>(dst_data[i]),
reinterpret_cast<void*>(src_data[i]),
static_cast<size_t>(size_data[i]), cudaMemcpyDefault,
stream);
}
#endif
}
namespace vllm {
// Grid: (num_layers, num_pairs)

View File

@@ -30,15 +30,13 @@
}()
namespace {
enum class FusedMOEAct { SiluAndMul, SwigluOAIAndMul, GeluAndMul };
enum class FusedMOEAct { SiluAndMul, SwigluOAIAndMul };
FusedMOEAct get_act_type(const std::string& act) {
if (act == "silu") {
return FusedMOEAct::SiluAndMul;
} else if (act == "swigluoai") {
return FusedMOEAct::SwigluOAIAndMul;
} else if (act == "gelu") {
return FusedMOEAct::GeluAndMul;
} else {
TORCH_CHECK(false, "Invalid act type: " + act);
}
@@ -106,43 +104,6 @@ void silu_and_mul(float* __restrict__ input, scalar_t* __restrict__ output,
}
}
template <typename scalar_t>
void gelu_and_mul(float* __restrict__ input, scalar_t* __restrict__ output,
const int32_t m_size, const int32_t n_size,
const int32_t input_stride, const int32_t output_stride) {
using scalar_vec_t = typename cpu_utils::VecTypeTrait<scalar_t>::vec_t;
const int32_t dim = n_size / 2;
float* __restrict__ gate = input;
float* __restrict__ up = input + dim;
vec_op::FP32Vec16 one_vec(1.0);
vec_op::FP32Vec16 w1_vec(M_SQRT1_2);
vec_op::FP32Vec16 w2_vec(0.5);
alignas(64) float temp[16];
DEFINE_FAST_EXP
for (int32_t m = 0; m < m_size; ++m) {
for (int32_t n = 0; n < dim; n += 16) {
vec_op::FP32Vec16 gate_vec(gate + n);
vec_op::FP32Vec16 up_vec(up + n);
auto er_input_vec = gate_vec * w1_vec;
er_input_vec.save(temp);
for (int32_t i = 0; i < 16; ++i) {
temp[i] = std::erf(temp[i]);
}
vec_op::FP32Vec16 er_vec(temp);
auto gelu = gate_vec * w2_vec * (one_vec + er_vec);
auto gated_output_fp32 = up_vec * gelu;
scalar_vec_t gated_output = scalar_vec_t(gated_output_fp32);
gated_output.save(output + n);
}
gate += input_stride;
up += input_stride;
output += output_stride;
}
}
template <typename scalar_t>
FORCE_INLINE void apply_gated_act(const FusedMOEAct act,
float* __restrict__ input,
@@ -157,9 +118,6 @@ FORCE_INLINE void apply_gated_act(const FusedMOEAct act,
case FusedMOEAct::SiluAndMul:
silu_and_mul(input, output, m, n, input_stride, output_stride);
return;
case FusedMOEAct::GeluAndMul:
gelu_and_mul(input, output, m, n, input_stride, output_stride);
return;
default:
TORCH_CHECK(false, "Unsupported act type.");
}

View File

@@ -8,7 +8,7 @@ Generate CPU attention dispatch switch cases and kernel instantiations.
import os
# Head dimensions divisible by 32 (support all ISAs)
HEAD_DIMS_32 = [32, 64, 96, 128, 160, 192, 224, 256, 512]
HEAD_DIMS_32 = [32, 64, 96, 128, 160, 192, 224, 256]
# Head dimensions divisible by 16 but not 32 (VEC16 only)
HEAD_DIMS_16 = [80, 112]

View File

@@ -117,14 +117,6 @@ inline void parallel_for(int n, const func_t& f) {
#endif
}
inline int get_thread_num() {
#if defined(_OPENMP)
return omp_get_thread_num();
#else
return 0;
#endif
}
// for 1d parallel, use `actual_nth`
// for 2d parallel, use even nths, e.g. 43->42
int inline adjust_num_threads(int m) {

View File

@@ -17,8 +17,8 @@ constexpr int block_size_n() { return 2 * TILE_N; }
template <typename T> inline bool can_use_brgemm(int M);
template <> inline bool can_use_brgemm<at::BFloat16>(int M) { return M > 4; }
template <> inline bool can_use_brgemm<at::Half>(int M) { return true; }
template <> inline bool can_use_brgemm<int8_t>(int M) { return M > 4; }
template <> inline bool can_use_brgemm<uint8_t>(int M) { return M > 4; }
// TODO: add u8s8 brgemm, this requires PyTorch 2.7
template <> inline bool can_use_brgemm<int8_t>(int M) { return false; }
template <> inline bool can_use_brgemm<at::Float8_e4m3fn>(int M) { return M > 4; }
template <> inline bool can_use_brgemm<at::quint4x2>(int M) { return M > 4; }
@@ -40,17 +40,9 @@ inline int64_t get_row_size(int64_t K, bool use_int8_w8a8) {
return use_int8_w8a8 ? K + sizeof(int32_t) : K;
}
inline int64_t get_4bit_block_k_size(int64_t group_size) {
return group_size > 128 ? 128 : group_size;
}
// pack weight into vnni format
// pack weight to vnni format
at::Tensor convert_weight_packed(at::Tensor& weight);
// pack weight to vnni format for int4 (adapted from sglang)
std::tuple<at::Tensor, at::Tensor, at::Tensor>
convert_weight_packed_scale_zp(at::Tensor qweight, at::Tensor qzeros, at::Tensor scales);
// moe implementations for int8 w8a8
template <typename scalar_t>
void fused_experts_int8_kernel_impl(
@@ -241,31 +233,6 @@ void tinygemm_kernel(
int64_t strideBs,
bool brg);
// int4 scaled GEMM (adapted from sglang)
at::Tensor int4_scaled_mm_cpu(
at::Tensor& x, at::Tensor& w, at::Tensor& w_zeros, at::Tensor& w_scales, std::optional<at::Tensor> bias);
// int4 tinygemm kernel interface(adapted from sglang)
template <typename scalar_t>
void tinygemm_kernel(
scalar_t* C,
float* C_temp,
const uint8_t* A,
const float* scales_a,
const int32_t* qzeros_a,
const uint8_t* B,
const float* scales_b,
const int8_t* qzeros_b,
const int32_t* compensation,
int8_t* dqB_tmp,
int64_t M,
int64_t K,
int64_t lda,
int64_t ldc_f,
int64_t ldc_s,
bool store_out,
bool use_brgemm);
// TODO: debug print, remove me later
inline void print_16x32i(const __m512i x) {
int32_t a[16];

View File

@@ -1,755 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// Adapted from sgl-project/sglang
// https://github.com/sgl-project/sglang/pull/8226
#include <ATen/ATen.h>
#include "common.h"
#include "gemm.h"
#include "vec.h"
namespace {
#define BLOCK_N block_size_n()
#define BLOCK_M 128
template <bool sym_quant_act>
struct ActDtype;
template <>
struct ActDtype<true> {
using type = int8_t;
};
template <>
struct ActDtype<false> {
using type = uint8_t;
};
struct alignas(32) m256i_wrapper {
__m256i data;
};
#if defined(CPU_CAPABILITY_AVX512)
inline std::array<m256i_wrapper, 2> load_zps_4vnni(
const int8_t* __restrict__ zps) {
__m256i vzps_low = _mm256_set1_epi64x(*reinterpret_cast<const int64_t*>(zps));
__m256i vzps_high =
_mm256_set1_epi64x(*reinterpret_cast<const int64_t*>(zps + 8));
__m256i shuffle_mask =
_mm256_set_epi8(7, 7, 7, 7, 6, 6, 6, 6, 5, 5, 5, 5, 4, 4, 4, 4, 3, 3, 3,
3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0);
vzps_low = _mm256_shuffle_epi8(vzps_low, shuffle_mask);
vzps_high = _mm256_shuffle_epi8(vzps_high, shuffle_mask);
m256i_wrapper vzps_low_wp, vzps_high_wp;
vzps_low_wp.data = vzps_low;
vzps_high_wp.data = vzps_high;
return {vzps_low_wp, vzps_high_wp};
}
inline std::array<m256i_wrapper, 2> load_uint4_as_int8(
const uint8_t* __restrict__ qB) {
__m256i packed = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(qB));
const __m256i low_mask = _mm256_set1_epi8(0x0f);
__m256i high = _mm256_srli_epi16(packed, 4);
high = _mm256_and_si256(high, low_mask);
__m256i low = _mm256_and_si256(packed, low_mask);
m256i_wrapper low_wp, high_wp;
low_wp.data = low;
high_wp.data = high;
return {low_wp, high_wp};
}
template <int N, int ldb>
void _dequant_weight_zp_only(const uint8_t* __restrict__ B, int8_t* dqB,
const int8_t* __restrict__ qzeros, int64_t K) {
#pragma GCC unroll 2
for (int n = 0; n < N; n += 16) {
auto [zps_low_wp, zps_high_wp] = load_zps_4vnni(&qzeros[n]);
auto zps_low = zps_low_wp.data;
auto zps_high = zps_high_wp.data;
for (int k = 0; k < K; k += 4) {
auto [vb_low_wp, vb_high_wp] =
load_uint4_as_int8(B + ldb * k + n / 2 * 4);
auto vb_low = vb_low_wp.data;
auto vb_high = vb_high_wp.data;
vb_high = _mm256_sub_epi8(vb_high, zps_high);
vb_low = _mm256_sub_epi8(vb_low, zps_low);
_mm256_storeu_si256(reinterpret_cast<__m256i_u*>(dqB + N * k + n * 4),
vb_low);
_mm256_storeu_si256(
reinterpret_cast<__m256i_u*>(dqB + N * k + (n + 8) * 4), vb_high);
}
}
}
template <bool sym_quant_act, int N, bool accum>
void _dequant_and_store(float* __restrict__ output,
const int32_t* __restrict__ input,
const float* __restrict__ scale_a,
const int32_t* __restrict__ zp_a,
const float* __restrict__ scale_b,
const int32_t* __restrict__ comp_b, int M, int ldi,
int ldo, int ldsa = 1) {
for (int m = 0; m < M; ++m) {
float a_scale = *(scale_a + m * ldsa);
__m512 va_scale = _mm512_set1_ps(a_scale);
int32_t a_zp;
__m512i va_zp;
if constexpr (!sym_quant_act) {
a_zp = *(zp_a + m * ldsa);
va_zp = _mm512_set1_epi32(a_zp);
}
int n = 0;
#pragma GCC unroll 2
for (; n < N; n += 16) {
__m512i vc = _mm512_loadu_si512(input + m * ldi + n);
if constexpr (!sym_quant_act) {
__m512i vb_comp = _mm512_loadu_si512(comp_b + n);
vc = _mm512_sub_epi32(vc, _mm512_mullo_epi32(vb_comp, va_zp));
}
__m512 vc_f = _mm512_cvtepi32_ps(vc);
__m512 vc_f_mul = _mm512_mul_ps(vc_f, va_scale);
__m512 vb_s = _mm512_loadu_ps(scale_b + n);
vc_f_mul = _mm512_mul_ps(vc_f_mul, vb_s);
if constexpr (accum) {
__m512 vo = _mm512_loadu_ps(output + m * ldo + n);
_mm512_storeu_ps(output + m * ldo + n, _mm512_add_ps(vo, vc_f_mul));
} else {
_mm512_storeu_ps(output + m * ldo + n, vc_f_mul);
}
}
for (; n < N; ++n) {
float dq_val;
if constexpr (sym_quant_act) {
dq_val = (float)input[m * ldi + n] * a_scale * scale_b[n];
} else {
dq_val = (float)(input[m * ldi + n] - a_zp * comp_b[n]) * a_scale *
scale_b[n];
}
if constexpr (accum) {
output[m * ldo + n] += dq_val;
} else {
output[m * ldo + n] = dq_val;
}
}
}
}
#else
template <int N, int ldb>
void _dequant_weight_zp_only(const uint8_t* B, int8_t* dqB,
const int8_t* qzeros, int64_t K) {
for (int k = 0; k < K; ++k) {
for (int n = 0; n < N / 2; ++n) {
int32_t b = (int32_t)B[k * ldb + n];
dqB[k * N + n * 2] = (b & 0xf) - qzeros[n];
dqB[k * N + n * 2 + 1] = (b >> 4) - qzeros[n];
}
}
}
#endif
#if defined(CPU_CAPABILITY_AVX512)
inline __m512i combine_m256i(__m256i a, __m256i b) {
__m512i c = _mm512_castsi256_si512(a);
return _mm512_inserti64x4(c, b, 1);
}
inline __m512i combine_m256i(std::array<m256i_wrapper, 2> two_256) {
return combine_m256i(two_256[0].data, two_256[1].data);
}
static inline __m512i _mm512_sign_epi8(__m512i a, __m512i b) {
__m512i zero = _mm512_setzero_si512();
__mmask64 blt0 = _mm512_movepi8_mask(b);
return _mm512_mask_sub_epi8(a, blt0, zero, a);
}
template <bool sym_quant_act, int M, int N, int ldb>
void _dequant_gemm_accum_small_M(float* __restrict__ C, const uint8_t* A,
const float* scales_a, const int32_t* qzeros_a,
const uint8_t* B, const float* scales_b,
const int8_t* qzeros_b, int64_t K, int64_t lda,
int64_t ldc) {
constexpr int COLS = N / 16;
__m512i ones = _mm512_set1_epi8(1);
__m512i va;
__m512i vb[COLS];
__m512i vc[M * COLS];
__m512 vscales[COLS];
__m512i vzps[COLS];
__m512i vcompensate[COLS];
Unroll<COLS>{}([&](auto i) {
vscales[i] = _mm512_loadu_ps(scales_b + i * 16);
vzps[i] = combine_m256i(load_zps_4vnni(qzeros_b + i * 16));
if constexpr (!sym_quant_act) {
vcompensate[i] = _mm512_setzero_epi32();
}
});
Unroll<M * COLS>{}([&](auto i) { vc[i] = _mm512_setzero_epi32(); });
auto compute = [&](auto i, int k) {
constexpr const int row = i / COLS;
constexpr const int col = i % COLS;
if constexpr (col == 0) {
va = _mm512_set1_epi32(*(int32_t*)(A + row * lda + k));
}
if constexpr (row == 0) {
int B_offset = k * ldb + col * 16 * 2;
vb[col] = combine_m256i(load_uint4_as_int8(B + B_offset));
vb[col] = _mm512_sub_epi8(vb[col], vzps[col]);
if constexpr (!sym_quant_act) {
vcompensate[col] = _mm512_dpbusd_epi32(vcompensate[col], ones, vb[col]);
}
_mm_prefetch(B + B_offset + 128 * ldb, _MM_HINT_T0);
}
if constexpr (sym_quant_act) {
auto vsb = _mm512_sign_epi8(vb[col], va);
auto vabsa = _mm512_sign_epi8(va, va);
vc[i] = _mm512_dpbusds_epi32(vc[i], vabsa, vsb);
} else {
vc[i] = _mm512_dpbusd_epi32(vc[i], va, vb[col]);
}
};
constexpr const int unroll = 4;
int k = 0;
for (; k < K / 4 / unroll; k++) {
Unroll<unroll>{}(
[&](auto i) { Unroll<M * COLS>{}(compute, 4 * (k * unroll + i)); });
}
k *= 4 * unroll;
for (; k < K; k += 4) {
Unroll<M * COLS>{}(compute, k);
}
auto store = [&](auto i) {
constexpr const int row = i / COLS;
constexpr const int col = i % COLS;
__m512 vc_float;
if constexpr (!sym_quant_act) {
vc[i] = _mm512_sub_epi32(
vc[i], _mm512_mullo_epi32(vcompensate[col],
_mm512_set1_epi32(*(qzeros_a + row))));
}
vc_float = _mm512_cvtepi32_ps(vc[i]);
vc_float = _mm512_mul_ps(vc_float, _mm512_set1_ps(*(scales_a + row)));
vc_float = _mm512_mul_ps(vc_float, vscales[col]);
auto vc_old = _mm512_loadu_ps(C + row * ldc + col * 16);
vc_float = _mm512_add_ps(vc_float, vc_old);
_mm512_storeu_ps(C + row * ldc + col * 16, vc_float);
};
Unroll<M * COLS>{}(store);
}
#define CALL_DEQUANT_GEMM_ACCUM_SMALL_M(M) \
_dequant_gemm_accum_small_M<sym_quant_act, M, N, ldb>( \
C, A, scales_a, qzeros_a, B, scales_b, qzeros_b, K, lda, ldc);
#endif
template <bool sym_quant_act, int N, int ldb>
void _dequant_gemm_accum(float* C, const uint8_t* A, const float* scales_a,
const int32_t* qzeros_a, const uint8_t* B,
const float* scales_b, const int8_t* qzeros_b,
const int32_t* compensation, int8_t* dqB, int64_t M,
int64_t K, int64_t lda, int64_t ldc, bool use_brgemm) {
#if defined(CPU_CAPABILITY_AVX512)
if (!use_brgemm) {
switch (M) {
case 1:
CALL_DEQUANT_GEMM_ACCUM_SMALL_M(1);
break;
case 2:
CALL_DEQUANT_GEMM_ACCUM_SMALL_M(2);
break;
case 3:
CALL_DEQUANT_GEMM_ACCUM_SMALL_M(3);
break;
case 4:
CALL_DEQUANT_GEMM_ACCUM_SMALL_M(4);
break;
default:
TORCH_CHECK(false, "tinygemm_kernel: unexpected M for AVX path!");
}
return;
}
_dequant_weight_zp_only<N, ldb>(B, dqB, qzeros_b, K);
using Tin = typename ActDtype<sym_quant_act>::type;
Tin* A_ptr = (Tin*)A;
if (use_brgemm) {
int32_t C_i32[M * N];
at::native::cpublas::brgemm(M, N, K, lda, N /*ldb*/, N /*ldc*/,
false /* add_C */, A_ptr, dqB, C_i32,
true /* is_vnni */);
_mm_prefetch(B + N * K / 2, _MM_HINT_T0);
_mm_prefetch(A + K, _MM_HINT_T0);
_dequant_and_store<sym_quant_act, N, true>(C, C_i32, scales_a, qzeros_a,
scales_b, compensation, M,
N /*ldi*/, ldc, 1 /*ldsa*/);
} else
#endif
{
TORCH_CHECK(false, "tinygemm_kernel: scalar path not implemented!");
}
}
template <int N>
inline void copy_bias(const float* bias_ptr, float* y_buf, int64_t m) {
if (bias_ptr) {
for (int i = 0; i < m; ++i) {
int j = 0;
#if defined(CPU_CAPABILITY_AVX512)
#pragma GCC unroll 2
for (; j < N; j += 16) {
__m512 bias_vec = _mm512_loadu_ps(bias_ptr + j);
_mm512_storeu_ps(y_buf + i * N + j, bias_vec);
}
#endif
for (; j < N; ++j) {
y_buf[i * N + j] = bias_ptr[j];
}
}
} else {
for (int i = 0; i < m; ++i) {
int j = 0;
#if defined(CPU_CAPABILITY_AVX512)
#pragma GCC unroll 2
for (; j < N; j += 16) {
__m512 zero_vec = _mm512_setzero_ps();
_mm512_storeu_ps(y_buf + i * N + j, zero_vec);
}
#endif
for (; j < N; ++j) {
y_buf[i * N + j] = 0;
}
}
}
}
template <int N, typename out_dtype>
inline void store_out(const float* y_buf, out_dtype* c_ptr, int64_t m,
int64_t lda) {
for (int i = 0; i < m; ++i) {
int j = 0;
if constexpr (std::is_same<out_dtype, float>::value) {
#if defined(CPU_CAPABILITY_AVX512)
#pragma GCC unroll 2
for (; j < N; j += 16) {
__m512 y_vec = _mm512_loadu_ps(y_buf + i * N + j);
_mm512_storeu_ps(c_ptr + i * lda + j, y_vec);
}
#endif
for (; j < N; ++j) {
c_ptr[i * lda + j] = y_buf[i * N + j];
}
} else if constexpr (std::is_same<out_dtype, at::BFloat16>::value) {
#if defined(CPU_CAPABILITY_AVX512)
#pragma GCC unroll 2
for (; j < N; j += 16) {
__m512 y_vec = _mm512_loadu_ps(y_buf + i * N + j);
__m256i y_bf16_vec = at::vec::cvtfp32_bf16(y_vec);
_mm256_storeu_si256(reinterpret_cast<__m256i*>(c_ptr + i * lda + j),
y_bf16_vec);
}
#endif
for (; j < N; ++j) {
c_ptr[i * lda + j] = at::BFloat16(y_buf[i * N + j]);
}
} else if constexpr (std::is_same<out_dtype, at::Half>::value) {
#if defined(CPU_CAPABILITY_AVX512)
#pragma GCC unroll 2
for (; j < N; j += 16) {
__m512 y_vec = _mm512_loadu_ps(y_buf + i * N + j);
__m256i y_fp16_vec = at::vec::cvtfp32_fp16(y_vec);
_mm256_storeu_si256(reinterpret_cast<__m256i*>(c_ptr + i * lda + j),
y_fp16_vec);
}
#endif
for (; j < N; ++j) {
c_ptr[i * lda + j] = at::Half(y_buf[i * N + j]);
}
} else {
TORCH_CHECK(false, "Unsupported output dtype");
}
}
}
void fill_val_stub(int32_t* __restrict__ output, int32_t value, int64_t size) {
using iVec = at::vec::Vectorized<int32_t>;
constexpr int VecSize = iVec::size();
const iVec fill_val_vec = iVec(value);
int64_t d;
#pragma GCC unroll 4
for (d = 0; d <= size - VecSize; d += VecSize) {
fill_val_vec.store(output + d);
}
for (; d < size; ++d) {
output[d] = value;
}
}
template <bool sym_quant_act, typename act_dtype, typename out_dtype>
void _da8w4_linear_impl(
act_dtype* __restrict__ input, const float* __restrict__ input_scales,
const int32_t* __restrict__ input_qzeros,
const uint8_t* __restrict__ weight, const float* __restrict__ weight_scales,
const int8_t* __restrict__ weight_qzeros, const float* __restrict__ bias,
out_dtype* __restrict__ output, float* __restrict__ output_temp,
int8_t* __restrict__ dequant_weight_temp, int64_t M, int64_t N, int64_t K,
int64_t num_groups) {
const bool use_brgemm = can_use_brgemm<act_dtype>(M);
int64_t block_m = [&]() -> long {
if (M <= 48) {
return M;
} else if (M < 64) {
return 32;
} else if (M < 96) {
return 64;
} else {
return 128;
}
}();
int64_t Mc = div_up(M, block_m);
bool parallel_on_M = M > 128;
int64_t Nc = N / BLOCK_N;
int64_t num_blocks = parallel_on_M ? Mc * Nc : Nc;
int64_t group_size = div_up(K, num_groups);
int64_t _block_k = get_4bit_block_k_size(group_size);
int64_t Kc = K / _block_k;
int64_t block_per_group = group_size / _block_k;
at::parallel_for(0, num_blocks, 1, [&](int64_t begin, int64_t end) {
int tid = get_thread_num();
float* C_tmp = output_temp + tid * block_m * BLOCK_N;
int8_t* dqB_tmp = dequant_weight_temp + tid * _block_k * BLOCK_N;
for (const auto i : c10::irange(begin, end)) {
int64_t mc = parallel_on_M ? i / Nc : 0;
int64_t nc = parallel_on_M ? i % Nc : i;
int64_t mc_end = parallel_on_M ? mc + 1 : Mc;
for (int mci = mc; mci < mc_end; ++mci) {
int64_t m_size =
mci * block_m + block_m > M ? M - mci * block_m : block_m;
auto bias_data = bias ? bias + nc * BLOCK_N : nullptr;
copy_bias<BLOCK_N>(bias_data, C_tmp, m_size);
for (int kci = 0; kci < Kc; ++kci) {
int32_t* compensation_ptr =
sym_quant_act
? nullptr
: (int32_t*)(void*)(weight +
(nc * Kc + kci) *
(BLOCK_N *
(_block_k / 2 + sizeof(int32_t))) +
_block_k * BLOCK_N / 2);
_dequant_gemm_accum<sym_quant_act, BLOCK_N, BLOCK_N / 2>(
/*C*/ C_tmp,
/*A*/ (uint8_t*)input + mci * block_m * K + kci * _block_k,
/*scales_a*/ input_scales + mci * block_m,
/*qzeros_a*/ input_qzeros + mci * block_m,
/*B*/ weight + (nc * Kc + kci) *
(BLOCK_N * (_block_k / 2 + sizeof(int32_t))),
/*scales_b*/ weight_scales + nc * BLOCK_N * num_groups +
kci / block_per_group * BLOCK_N,
/*qzeros_b*/ weight_qzeros + nc * BLOCK_N * num_groups +
kci / block_per_group * BLOCK_N,
/*Bcomp*/ compensation_ptr,
/*dqB_tmp*/ dqB_tmp,
/*M*/ m_size,
/*K*/ _block_k,
/*lda*/ K,
/*ldc*/ BLOCK_N,
/*use_brgemm*/ use_brgemm);
}
store_out<BLOCK_N>(C_tmp, output + mci * block_m * N + nc * BLOCK_N,
m_size, N /*lda*/);
}
}
if (use_brgemm) {
at::native::cpublas::brgemm_release();
}
});
}
} // anonymous namespace
std::tuple<at::Tensor, at::Tensor, at::Tensor>
convert_int4_weight_packed_with_compensation(const at::Tensor& weight,
const at::Tensor& scales,
const at::Tensor& qzeros) {
TORCH_CHECK(weight.dim() == 2,
"DA8W4 CPU: Weight should be a 2D tensor for packing");
TORCH_CHECK(
weight.size(1) % 2 == 0,
"DA8W4 CPU: Weight should have even number of columns for packing");
auto new_scales = scales;
auto new_qzeros = qzeros;
if (new_scales.dim() == 1) {
new_scales.unsqueeze_(1);
}
new_scales = new_scales.to(at::kFloat);
if (new_qzeros.dim() == 1) {
new_qzeros.unsqueeze_(1);
}
new_qzeros = new_qzeros.to(at::kChar);
int64_t N = weight.size(0);
int64_t K = weight.size(1);
int64_t G = scales.size(1);
int64_t group_size = K / G;
int64_t _block_k = get_4bit_block_k_size(group_size);
constexpr int block_n = block_size_n();
int64_t Nc = N / block_n;
int64_t Kc = K / _block_k;
auto weight_view = weight.view({Nc, block_n, Kc, _block_k});
at::Tensor weight_reordered = weight_view.permute({0, 2, 3, 1}).contiguous();
at::Tensor blocked_weight;
at::Tensor blocked_scales =
new_scales.view({Nc, block_n, G}).permute({0, 2, 1}).contiguous();
at::Tensor blocked_qzeros =
new_qzeros.view({Nc, block_n, G}).permute({0, 2, 1}).contiguous();
auto weight_sub_qzero = weight.view({Nc, block_n, G, -1}).to(at::kInt) -
new_qzeros.view({Nc, block_n, G, -1});
weight_sub_qzero = weight_sub_qzero.view({Nc, block_n, Kc, _block_k});
at::Tensor compensation = weight_sub_qzero.sum(-1);
compensation = compensation.permute({0, 2, 1}).contiguous().to(at::kInt);
int64_t buffer_size_nbytes =
_block_k * block_n / 2 + block_n * sizeof(int32_t);
blocked_weight = at::empty({Nc, Kc, buffer_size_nbytes}, weight.options());
auto weight_ptr = weight_reordered.data_ptr<uint8_t>();
auto compensation_ptr = compensation.data_ptr<int32_t>();
auto blocked_weight_ptr = blocked_weight.data_ptr<uint8_t>();
int64_t num_blocks = Nc * Kc;
at::parallel_for(0, num_blocks, 1, [&](int64_t begin, int64_t end) {
for (const auto i : c10::irange(begin, end)) {
auto in_ptr = weight_ptr + i * _block_k * block_n;
auto out_ptr =
blocked_weight_ptr + i * block_n * (_block_k / 2 + sizeof(int32_t));
int32_t* comp_in_prt = compensation_ptr + i * block_n;
int32_t* comp_out_prt =
(int32_t*)(void*)(blocked_weight_ptr +
i * block_n * (_block_k / 2 + sizeof(int32_t)) +
_block_k * block_n / 2);
constexpr int n_group_size = 8;
constexpr int vnni_size = 4;
constexpr int n_group = block_n / n_group_size;
for (int nb = 0; nb < n_group; nb += 2) {
for (int k = 0; k < _block_k; k += vnni_size) {
for (int ni = 0; ni < n_group_size; ++ni) {
for (int ki = 0; ki < vnni_size; ++ki) {
int src_idx_1 = nb * n_group_size + ni + (k + ki) * block_n;
int src_idx_2 = (nb + 1) * n_group_size + ni + (k + ki) * block_n;
int dst_idx = (nb / 2 * n_group_size + ni) * vnni_size +
k * block_n / 2 + ki;
uint8_t src_1 = *(in_ptr + src_idx_1);
uint8_t src_2 = *(in_ptr + src_idx_2);
uint8_t dst = (src_1 & 0x0f) | ((src_2 & 0x0f) << 4);
*(out_ptr + dst_idx) = dst;
}
}
}
}
for (int nb = 0; nb < block_n; nb++) {
*(comp_out_prt + nb) = *(comp_in_prt + nb);
}
}
});
return std::make_tuple(std::move(blocked_weight), std::move(blocked_scales),
std::move(blocked_qzeros));
}
std::tuple<at::Tensor, at::Tensor> autoawq_to_int4pack(at::Tensor qweight,
at::Tensor qzeros) {
auto bitshifts = at::tensor({0, 4, 1, 5, 2, 6, 3, 7}, at::kInt) * 4;
auto qweight_unsq = qweight.unsqueeze(-1);
auto unpacked = at::bitwise_right_shift(qweight_unsq, bitshifts) & 0xF;
auto qweight_final = unpacked.flatten(-2).transpose(-1, -2).to(at::kByte);
auto qzeros_unsq = qzeros.unsqueeze(-1);
auto qzeros_unpacked = at::bitwise_right_shift(qzeros_unsq, bitshifts) & 0xF;
auto qzeros_final = qzeros_unpacked.flatten(-2).to(at::kByte);
return std::make_tuple(qweight_final, qzeros_final);
}
std::tuple<at::Tensor, at::Tensor, at::Tensor> convert_weight_packed_scale_zp(
at::Tensor qweight, at::Tensor qzeros, at::Tensor scales) {
auto res = autoawq_to_int4pack(qweight, qzeros);
auto _qweight = std::get<0>(res);
auto _qzeros = std::get<1>(res);
auto _scales = scales;
_qzeros = _qzeros.transpose(-2, -1).contiguous();
_scales = _scales.transpose(-2, -1).contiguous();
if (_qweight.dim() == 3) {
int64_t E = _qweight.size(0);
int64_t K = _qweight.size(2);
int64_t G = _scales.size(2);
int64_t group_size = K / G;
int64_t _block_k = get_4bit_block_k_size(group_size);
int64_t block_n = block_size_n();
int64_t Nc = _qweight.size(1) / block_n;
int64_t Kc = K / _block_k;
int64_t buffer_size_nbytes =
_block_k * block_n / 2 + block_n * sizeof(int32_t);
auto blocked_weight =
at::empty({E, Nc, Kc, buffer_size_nbytes}, _qweight.options());
auto blocked_scales =
at::empty({E, Nc, G, block_n}, _scales.options()).to(at::kFloat);
auto blocked_qzeros =
at::empty({E, Nc, G, block_n}, _qzeros.options()).to(at::kChar);
for (int i = 0; i < _qweight.size(0); i++) {
auto res_ = convert_int4_weight_packed_with_compensation(
_qweight[i], _scales[i], _qzeros[i]);
blocked_weight[i] = std::get<0>(res_);
blocked_scales[i] = std::get<1>(res_);
blocked_qzeros[i] = std::get<2>(res_);
}
_qweight = blocked_weight;
_scales = blocked_scales;
_qzeros = blocked_qzeros;
} else {
auto res_ = convert_int4_weight_packed_with_compensation(_qweight, _scales,
_qzeros);
_qweight = std::get<0>(res_);
_scales = std::get<1>(res_);
_qzeros = std::get<2>(res_);
}
return std::make_tuple(_qweight, _qzeros, _scales);
}
at::Tensor int4_scaled_mm_cpu_with_quant(const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& weight_scales,
const at::Tensor& weight_qzeros,
const std::optional<at::Tensor>& bias,
at::ScalarType output_dtype) {
RECORD_FUNCTION("vllm::int4_scaled_mm_cpu_with_quant",
std::vector<c10::IValue>({input, weight}));
int64_t M_a = input.size(0);
int64_t K_a = input.size(1);
int64_t lda = input.stride(0);
const auto st = input.scalar_type();
TORCH_CHECK(
st == at::kBFloat16 || st == at::kHalf,
"int4_scaled_mm_cpu_with_quant: expect A to be bfloat16 or half.");
constexpr bool sym_quant_act = false;
using Tin = typename ActDtype<sym_quant_act>::type;
int64_t act_buffer_size =
M_a * K_a + M_a * sizeof(float) + M_a * sizeof(int32_t);
auto act_buffer =
at::empty({act_buffer_size}, input.options().dtype(at::kByte));
auto Aq_data = act_buffer.data_ptr<uint8_t>();
auto As_data = reinterpret_cast<float*>(Aq_data + M_a * K_a);
auto Azp_data = reinterpret_cast<int32_t*>(As_data + M_a);
fill_val_stub(Azp_data, 128, M_a);
auto out_sizes = input.sizes().vec();
int64_t N = weight_scales.size(0) * weight_scales.size(-1);
out_sizes.back() = N;
auto output = at::empty(out_sizes, input.options());
int64_t Nc = weight.size(0);
int64_t Kc = weight.size(1);
int64_t _block_k = K_a / Kc;
TORCH_CHECK(N == Nc * BLOCK_N, "DA8W4: weight and input shapes mismatch");
int64_t num_groups = weight_scales.size(1);
const uint8_t* b_ptr = weight.data_ptr<uint8_t>();
const float* b_scales_ptr = weight_scales.data_ptr<float>();
const int8_t* b_qzeros_ptr = weight_qzeros.data_ptr<int8_t>();
const float* bias_ptr =
bias.has_value() ? bias.value().data_ptr<float>() : nullptr;
int num_threads = at::get_num_threads();
int64_t temp_buffer_size = num_threads * BLOCK_M * BLOCK_N * sizeof(float) +
num_threads * _block_k * BLOCK_N;
auto c_temp_buffer =
at::empty({temp_buffer_size}, input.options().dtype(at::kChar));
float* c_temp_ptr = (float*)((void*)(c_temp_buffer.data_ptr<int8_t>()));
int8_t* dqB_temp_ptr =
(int8_t*)((void*)(c_temp_ptr + num_threads * BLOCK_M * BLOCK_N));
#define LAUNCH_DA8W4_LINEAR_WITH_QUANT_IMPL(sym_quant_act) \
AT_DISPATCH_FLOATING_TYPES_AND2( \
at::ScalarType::BFloat16, at::ScalarType::Half, output_dtype, \
"int4_scaled_mm_cpu", [&] { \
const scalar_t* __restrict__ A_data = input.data_ptr<scalar_t>(); \
scalar_t* __restrict__ c_ptr = output.data_ptr<scalar_t>(); \
at::parallel_for(0, M_a, 0, [&](int64_t begin, int64_t end) { \
for (int64_t m = begin; m < end; ++m) { \
quantize_row_int8<scalar_t>(Aq_data + m * K_a, As_data[m], \
A_data + m * lda, K_a); \
} \
}); \
_da8w4_linear_impl<sym_quant_act, Tin, scalar_t>( \
Aq_data, As_data, Azp_data, b_ptr, b_scales_ptr, b_qzeros_ptr, \
bias_ptr, c_ptr, c_temp_ptr, dqB_temp_ptr, M_a, N, K_a, \
num_groups); \
});
LAUNCH_DA8W4_LINEAR_WITH_QUANT_IMPL(sym_quant_act);
return output;
}
namespace {
template <typename scalar_t>
inline void copy_stub(scalar_t* __restrict__ out,
const float* __restrict__ input, int64_t size) {
using Vec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
#pragma GCC unroll 4
for (int64_t d = 0; d < size; d += Vec::size()) {
fVec x0 = fVec::loadu(input + d);
fVec x1 = fVec::loadu(input + d + fVec::size());
Vec res = convert_from_float_ext<scalar_t>(x0, x1);
res.store(out + d);
}
}
} // anonymous namespace
template <typename scalar_t>
void tinygemm_kernel(scalar_t* C, float* C_temp, const uint8_t* A,
const float* scales_a, const int32_t* qzeros_a,
const uint8_t* B, const float* scales_b,
const int8_t* qzeros_b, const int32_t* compensation,
int8_t* dqB_tmp, int64_t M, int64_t K, int64_t lda,
int64_t ldc_f, int64_t ldc_s, bool store_out,
bool use_brgemm) {
_dequant_gemm_accum<false, BLOCK_N, BLOCK_N / 2>(
C_temp, A, scales_a, qzeros_a, B, scales_b, qzeros_b, compensation,
dqB_tmp, M, K, lda, ldc_f, use_brgemm);
if (store_out) {
for (int64_t m = 0; m < M; ++m) {
copy_stub<scalar_t>(C + m * ldc_s, C_temp + m * ldc_f, BLOCK_N);
}
}
}
#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \
template void tinygemm_kernel<TYPE>( \
TYPE * C, float* C_temp, const uint8_t* A, const float* scales_a, \
const int32_t* qzeros_a, const uint8_t* B, const float* scales_b, \
const int8_t* qzeros_b, const int32_t* compensation, int8_t* dqB_tmp, \
int64_t M, int64_t K, int64_t lda, int64_t ldc_f, int64_t ldc_s, \
bool store_out, bool use_brgemm)
INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16);
INSTANTIATE_TINYGEMM_TEMPLATE(at::Half);
at::Tensor int4_scaled_mm_cpu(at::Tensor& x, at::Tensor& w, at::Tensor& w_zeros,
at::Tensor& w_scales,
std::optional<at::Tensor> bias) {
return int4_scaled_mm_cpu_with_quant(x, w, w_scales, w_zeros, bias,
x.scalar_type());
}

View File

@@ -79,14 +79,6 @@ at::Tensor int8_scaled_mm_with_quant(at::Tensor& mat1, at::Tensor& mat2,
const std::optional<at::Tensor>& bias,
at::ScalarType out_dtype, bool is_vnni);
// Adapted from sglang: INT4 W4A8 kernels
std::tuple<at::Tensor, at::Tensor, at::Tensor> convert_weight_packed_scale_zp(
at::Tensor qweight, at::Tensor qzeros, at::Tensor scales);
at::Tensor int4_scaled_mm_cpu(at::Tensor& x, at::Tensor& w, at::Tensor& w_zeros,
at::Tensor& w_scales,
std::optional<at::Tensor> bias);
torch::Tensor get_scheduler_metadata(
const int64_t num_req, const int64_t num_heads_q,
const int64_t num_heads_kv, const int64_t head_dim,
@@ -293,18 +285,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor? bias, ScalarType out_dtype, bool is_vnni) -> Tensor");
ops.impl("int8_scaled_mm_with_quant", torch::kCPU,
&int8_scaled_mm_with_quant);
// Adapted from sglang: INT4 W4A8 kernels
ops.def(
"convert_weight_packed_scale_zp(Tensor qweight, Tensor qzeros, "
"Tensor scales) -> (Tensor, Tensor, Tensor)");
ops.impl("convert_weight_packed_scale_zp", torch::kCPU,
&convert_weight_packed_scale_zp);
ops.def(
"int4_scaled_mm_cpu(Tensor(a0!) x, Tensor(a1!) w, Tensor(a2!) w_zeros, "
"Tensor(a3!) w_scales, Tensor? bias) -> Tensor");
ops.impl("int4_scaled_mm_cpu", torch::kCPU, &int4_scaled_mm_cpu);
#endif
// CPU attention kernels

View File

@@ -3,8 +3,8 @@
#pragma once
#include <torch/headeronly/util/BFloat16.h>
#include <torch/headeronly/util/Half.h>
#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
#include <cassert>
#ifdef USE_ROCM

View File

@@ -6,16 +6,14 @@
#include <cstdio>
#include <cstdlib>
#include <torch/headeronly/util/shim_utils.h>
/**
* Helper function for checking CUTLASS errors
*/
#define CUTLASS_CHECK(status) \
{ \
cutlass::Status error = status; \
STD_TORCH_CHECK(error == cutlass::Status::kSuccess, \
cutlassGetStatusString(error)); \
#define CUTLASS_CHECK(status) \
{ \
cutlass::Status error = status; \
TORCH_CHECK(error == cutlass::Status::kSuccess, \
cutlassGetStatusString(error)); \
}
inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {

View File

@@ -1,6 +1,7 @@
#pragma once
#include <cute/tensor.hpp>
#include <torch/all.h>
namespace cute {
////////////////////////////////////////////////////////////////////

View File

@@ -189,9 +189,9 @@ struct Sm90RowOrScalarBroadcastArray {
}
auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
cute::Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
cute::Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
cute::Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
for (int i = 0; i < size(tGS_gRow_flt); ++i) {
if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
@@ -211,8 +211,8 @@ struct Sm90RowOrScalarBroadcastArray {
begin_loop(int epi_m, int epi_n) {
if (epi_m == 0) { // Assumes M-major subtile loop
if (!params.row_broadcast) return; // Do not issue LDS when row is scalar
cute::Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
cute::Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
copy(tSR_sRow_flt, tSR_rRow_flt);
}
}
@@ -241,9 +241,9 @@ struct Sm90RowOrScalarBroadcastArray {
auto [m, n, k, l] = args.tile_coord_mnkl;
using ThreadCount = decltype(size(args.tiled_copy));
cute::Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row_array[l]), make_shape(M,N,1), params.dRow);
cute::Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
cute::Tensor sRow = make_tensor(make_smem_ptr(smem),
Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row_array[l]), make_shape(M,N,1), params.dRow);
Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
Tensor sRow = make_tensor(make_smem_ptr(smem),
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N)
//// G2S: Gmem to Smem
auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
@@ -251,16 +251,16 @@ struct Sm90RowOrScalarBroadcastArray {
Stride<_0, _1>>{},
Layout<_1>{});
auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
cute::Tensor tGS_gRow = thr_g2s.partition_S(gRow);
cute::Tensor tGS_sRow = thr_g2s.partition_D(sRow);
Tensor tGS_gRow = thr_g2s.partition_S(gRow);
Tensor tGS_sRow = thr_g2s.partition_D(sRow);
//// G2S: Coord
auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
cute::Tensor tGS_cRow = thr_g2s.partition_S(cRow);
Tensor tGS_cRow = thr_g2s.partition_S(cRow);
//// S2R: Smem to Reg
cute::Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
cute::Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
return ConsumerStoreCallbacks<decltype(tGS_gRow), decltype(tGS_sRow), decltype(tGS_cRow), decltype(tiled_g2s), decltype(tSR_sRow), decltype(tSR_rRow), decltype(args.tCcD), decltype(args.residue_cD), ThreadCount>(
tGS_gRow,
@@ -389,7 +389,7 @@ struct Sm90ColOrScalarBroadcastArray {
CUTLASS_DEVICE void
begin() {
cute::Tensor pred = make_tensor<bool>(shape(tCgCol));
Tensor pred = make_tensor<bool>(shape(tCgCol));
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(pred); ++i) {
pred(i) = get<0>(tCcCol(i)) < m;
@@ -409,7 +409,7 @@ struct Sm90ColOrScalarBroadcastArray {
CUTLASS_DEVICE Array<Element, FragmentSize>
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
Array<Element, FragmentSize> frg_col;
cute::Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < FragmentSize; ++i) {
@@ -431,16 +431,16 @@ struct Sm90ColOrScalarBroadcastArray {
auto [M, N, K, L] = args.problem_shape_mnkl;
auto [m, n, k, l] = args.tile_coord_mnkl;
cute::Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col_array[l]), make_shape(M,N,1), params.dCol);
cute::Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col_array[l]), make_shape(M,N,1), params.dCol);
Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
cute::Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
// Generate an identity tensor matching the shape of the global tensor and
// partition the same way, this will be used to generate the predicate
// tensor for loading
cute::Tensor cCol = make_identity_tensor(mCol.shape());
cute::Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
Tensor cCol = make_identity_tensor(mCol.shape());
Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
return ConsumerStoreCallbacks(

View File

@@ -186,9 +186,9 @@ struct Sm90RowOrScalarBroadcast {
}
auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
cute::Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
cute::Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
cute::Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
for (int i = 0; i < size(tGS_gRow_flt); ++i) {
if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
@@ -208,8 +208,8 @@ struct Sm90RowOrScalarBroadcast {
begin_loop(int epi_m, int epi_n) {
if (epi_m == 0) { // Assumes M-major subtile loop
if (!params.row_broadcast) return; // Do not issue LDS when row is scalar
cute::Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
cute::Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
copy(tSR_sRow_flt, tSR_rRow_flt);
}
}
@@ -238,9 +238,9 @@ struct Sm90RowOrScalarBroadcast {
auto [m, n, k, l] = args.tile_coord_mnkl;
using ThreadCount = decltype(size(args.tiled_copy));
cute::Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
cute::Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
cute::Tensor sRow = make_tensor(make_smem_ptr(smem),
Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
Tensor sRow = make_tensor(make_smem_ptr(smem),
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N)
//// G2S: Gmem to Smem
auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
@@ -248,16 +248,16 @@ struct Sm90RowOrScalarBroadcast {
Stride<_0, _1>>{},
Layout<_1>{});
auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
cute::Tensor tGS_gRow = thr_g2s.partition_S(gRow);
cute::Tensor tGS_sRow = thr_g2s.partition_D(sRow);
Tensor tGS_gRow = thr_g2s.partition_S(gRow);
Tensor tGS_sRow = thr_g2s.partition_D(sRow);
//// G2S: Coord
auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
cute::Tensor tGS_cRow = thr_g2s.partition_S(cRow);
Tensor tGS_cRow = thr_g2s.partition_S(cRow);
//// S2R: Smem to Reg
cute::Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
cute::Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
return ConsumerStoreCallbacks<decltype(tGS_gRow), decltype(tGS_sRow), decltype(tGS_cRow), decltype(tiled_g2s), decltype(tSR_sRow), decltype(tSR_rRow), decltype(args.tCcD), decltype(args.residue_cD), ThreadCount>(
tGS_gRow,
@@ -382,7 +382,7 @@ struct Sm90ColOrScalarBroadcast {
CUTLASS_DEVICE void
begin() {
cute::Tensor pred = make_tensor<bool>(shape(tCgCol));
Tensor pred = make_tensor<bool>(shape(tCgCol));
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(pred); ++i) {
pred(i) = get<0>(tCcCol(i)) < m;
@@ -402,7 +402,7 @@ struct Sm90ColOrScalarBroadcast {
CUTLASS_DEVICE Array<Element, FragmentSize>
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
Array<Element, FragmentSize> frg_col;
cute::Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < FragmentSize; ++i) {
@@ -422,16 +422,16 @@ struct Sm90ColOrScalarBroadcast {
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
auto [M, N, K, L] = args.problem_shape_mnkl;
cute::Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol);
cute::Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol);
Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
cute::Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
// Generate an identity tensor matching the shape of the global tensor and
// partition the same way, this will be used to generate the predicate
// tensor for loading
cute::Tensor cCol = make_identity_tensor(mCol.shape());
cute::Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
Tensor cCol = make_identity_tensor(mCol.shape());
Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
return ConsumerStoreCallbacks(

View File

@@ -1,7 +1,5 @@
#pragma once
#include <torch/csrc/stable/tensor.h>
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp"
/*
@@ -54,7 +52,7 @@ struct ScaledEpilogueBase {
// from a tensor. It can handle both row and column, as well as row/column or
// scalar cases.
template <typename Descriptor, typename T>
static auto args_from_tensor(torch::stable::Tensor const& tensor) {
static auto args_from_tensor(torch::Tensor const& tensor) {
using Arguments = typename Descriptor::Arguments;
auto* data_ptr = static_cast<T*>(tensor.data_ptr());
if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
@@ -70,8 +68,7 @@ struct ScaledEpilogueBase {
// This overload handles the case where there might not be a tensor, in which
// case a nullptr is passed and a constant (0) is used.
template <typename Descriptor, typename T>
static auto args_from_tensor(
std::optional<torch::stable::Tensor> const& tensor) {
static auto args_from_tensor(std::optional<torch::Tensor> const& tensor) {
static_assert(std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
using Arguments = typename Descriptor::Arguments;
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
@@ -120,8 +117,8 @@ struct ScaledEpilogue
cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales) {
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
@@ -163,9 +160,9 @@ struct ScaledEpilogueBias
using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
EVTCompute0, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& bias) {
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
@@ -223,11 +220,10 @@ struct ScaledEpilogueBiasAzp
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& bias) {
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
std::optional<torch::Tensor> const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
@@ -302,11 +298,11 @@ struct ScaledEpilogueBiasAzpToken
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& azp_adj, torch::stable::Tensor const& azp,
std::optional<torch::stable::Tensor> const& bias) {
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
torch::Tensor const& azp,
std::optional<torch::Tensor> const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);

View File

@@ -3,14 +3,6 @@
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp"
// This header is shared by both _C (unstable ABI) and _C_stable_libtorch
// (stable ABI) targets. When compiled under the stable ABI target,
// TORCH_TARGET_VERSION is defined and Tensor is unavailable, so we
// use torch::stable::Tensor instead.
#ifdef TORCH_TARGET_VERSION
#include <torch/csrc/stable/tensor.h>
#endif
/*
This file defines custom epilogues for fusing channel scales, token scales,
bias, and activation zero-points onto a GEMM operation using the
@@ -23,12 +15,6 @@
namespace vllm::c3x {
#ifdef TORCH_TARGET_VERSION
using TensorType = torch::stable::Tensor;
#else
using TensorType = torch::Tensor;
#endif
using namespace cute;
template <typename T>
@@ -98,7 +84,7 @@ struct ScaledEpilogueBase {
// from a tensor. It can handle both row and column, as well as row/column or
// scalar cases.
template <typename Descriptor, typename T>
static auto args_from_tensor(TensorType const& tensor) {
static auto args_from_tensor(torch::Tensor const& tensor) {
using Arguments = typename Descriptor::Arguments;
auto* data_ptr = static_cast<T*>(tensor.data_ptr());
if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
@@ -114,7 +100,7 @@ struct ScaledEpilogueBase {
// This overload handles the case where there might not be a tensor, in which
// case a nullptr is passed and a constant (0) is used.
template <typename Descriptor, typename T>
static auto args_from_tensor(std::optional<TensorType> const& tensor) {
static auto args_from_tensor(std::optional<torch::Tensor> const& tensor) {
using Arguments = typename Descriptor::Arguments;
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
static_assert(std::is_same_v<Descriptor, ColLoad<T, true>> ||
@@ -172,8 +158,8 @@ struct ScaledEpilogue
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(TensorType const& a_scales,
TensorType const& b_scales) {
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
@@ -217,9 +203,9 @@ struct ScaledEpilogueBias
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(TensorType const& a_scales,
TensorType const& b_scales,
TensorType const& bias) {
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
@@ -260,9 +246,9 @@ struct ScaledEpilogueColumnBias
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(TensorType const& a_scales,
TensorType const& b_scales,
TensorType const& bias) {
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
@@ -318,10 +304,10 @@ struct ScaledEpilogueBiasAzp
EVTComputeScaleB, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(TensorType const& a_scales,
TensorType const& b_scales,
TensorType const& azp_adj,
std::optional<TensorType> const& bias) {
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
std::optional<torch::Tensor> const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
@@ -394,11 +380,11 @@ struct ScaledEpilogueBiasAzpToken
EVTComputeScaleB, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(TensorType const& a_scales,
TensorType const& b_scales,
TensorType const& azp_adj,
TensorType const& azp,
std::optional<TensorType> const& bias) {
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
torch::Tensor const& azp,
std::optional<torch::Tensor> const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);

View File

@@ -1,21 +1,6 @@
#pragma once
// This header is shared between _C (unstable ABI, used by machete) and
// _C_stable_libtorch (stable ABI, used by W4A8/sparse). TORCH_TARGET_VERSION
// is defined only for the stable target, so we switch includes and types
// accordingly. TorchTensor (not Tensor) avoids ambiguity with cute::Tensor.
#ifdef TORCH_TARGET_VERSION
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/util/BFloat16.h>
#include <torch/headeronly/util/Half.h>
#include <torch/headeronly/util/shim_utils.h> // for STD_TORCH_CHECK
using TorchTensor = torch::stable::Tensor;
#define TORCH_UTILS_CHECK STD_TORCH_CHECK
#else
#include <torch/all.h>
using TorchTensor = torch::Tensor;
#define TORCH_UTILS_CHECK TORCH_CHECK
#endif
#include <torch/all.h>
#include "cute/layout.hpp"
#include "cutlass/layout/matrix.h"
@@ -70,35 +55,35 @@ CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f) {
// If `tensor.dim() < rank(Stride{})`, the shape is padded with 1s and the extra
// strides are set to be 0 or 1.
template <typename Stride>
static inline auto make_cute_layout(TorchTensor const& tensor,
static inline auto make_cute_layout(torch::Tensor const& tensor,
std::string_view name = "tensor") {
TORCH_UTILS_CHECK(tensor.dim() <= rank(Stride{}));
auto stride = cute::transform_with_idx(Stride{}, [&](auto const& stride_ele,
auto const& idx) {
using StrideEle = std::decay_t<decltype(stride_ele)>;
TORCH_CHECK(tensor.dim() <= rank(Stride{}));
auto stride = cute::transform_with_idx(
Stride{}, [&](auto const& stride_ele, auto const& idx) {
using StrideEle = std::decay_t<decltype(stride_ele)>;
if (idx < tensor.dim()) {
if constexpr (cute::is_static_v<StrideEle>) {
TORCH_UTILS_CHECK(StrideEle::value == tensor.stride(idx), "Expected ",
name, ".stride(", idx, ") to be ", StrideEle::value);
return StrideEle{};
} else {
if (tensor.size(idx) == 1) {
// use 0 stride for dim with size 1, this is easier for
// cute/cutlass to optimize (helps the TMA code flatten dims)
return StrideEle{0};
if (idx < tensor.dim()) {
if constexpr (cute::is_static_v<StrideEle>) {
TORCH_CHECK(StrideEle::value == tensor.stride(idx), "Expected ",
name, ".stride(", idx, ") to be ", StrideEle::value);
return StrideEle{};
} else {
if (tensor.size(idx) == 1) {
// use 0 stride for dim with size 1, this is easier for
// cute/cutlass to optimize (helps the TMA code flatten dims)
return StrideEle{0};
} else {
return tensor.stride(idx);
}
}
} else {
return tensor.stride(idx);
// Extra strides are assumed to be 0 or 1
if constexpr (cute::is_static_v<StrideEle>) {
static_assert(StrideEle::value == 0 || StrideEle::value == 1);
}
return StrideEle{};
}
}
} else {
// Extra strides are assumed to be 0 or 1
if constexpr (cute::is_static_v<StrideEle>) {
static_assert(StrideEle::value == 0 || StrideEle::value == 1);
}
return StrideEle{};
}
});
});
auto shape = cute::make_shape_from_idx<rank(Stride{})>([&](auto const& idx) {
if (idx < tensor.dim())
@@ -112,7 +97,7 @@ static inline auto make_cute_layout(TorchTensor const& tensor,
template <typename Stride>
static inline auto maybe_make_cute_layout(
std::optional<TorchTensor> const& tensor,
std::optional<torch::Tensor> const& tensor,
std::string_view name = "tensor") {
using Layout = decltype(make_cute_layout<Stride>(*tensor));
@@ -136,12 +121,12 @@ template <typename T>
using equivalent_cutlass_type_t = typename equivalent_cutlass_type<T>::type;
template <>
struct equivalent_cutlass_type<torch::headeronly::Half> {
struct equivalent_cutlass_type<c10::Half> {
using type = cutlass::half_t;
};
template <>
struct equivalent_cutlass_type<torch::headeronly::BFloat16> {
struct equivalent_cutlass_type<c10::BFloat16> {
using type = cutlass::bfloat16_t;
};
@@ -149,8 +134,8 @@ struct equivalent_cutlass_type<torch::headeronly::BFloat16> {
// equivalent_scalar_t (basically inverse of equivalent_cutlass_type)
//
// Return a `torch::headeronly::CppTypeToScalarType<T>` compatible type, i.e.
// get the C++ type equivalent to T, e.g.: `cutlass::half_t -> Half`
// Return a `c10::CppTypeToScalarType<T>` compatible type, i.e. get the C++ from
// c10 that is equivalent to T, e.g.: `cutlass::half_t -> c10::Half`
template <typename T>
struct equivalent_scalar_type {
using type = T;
@@ -161,15 +146,15 @@ using equivalent_scalar_type_t = typename equivalent_scalar_type<T>::type;
template <>
struct equivalent_scalar_type<cutlass::half_t> {
using type = torch::headeronly::Half;
using type = c10::Half;
};
template <>
struct equivalent_scalar_type<cutlass::bfloat16_t> {
using type = torch::headeronly::BFloat16;
using type = c10::BFloat16;
};
// get equivalent torch::headeronly::ScalarType tag from compile time type
// get equivalent c10::ScalarType tag from compile time type
template <typename T>
static inline constexpr torch::headeronly::ScalarType equivalent_scalar_type_v =
torch::headeronly::CppTypeToScalarType<equivalent_scalar_type_t<T>>::value;
static inline constexpr c10::ScalarType equivalent_scalar_type_v =
c10::CppTypeToScalarType<equivalent_scalar_type_t<T>>::value;

View File

@@ -49,15 +49,6 @@
THO_DISPATCH_SWITCH(TYPE, NAME, \
VLLM_STABLE_DISPATCH_CASE_FP8_TYPES(__VA_ARGS__))
// Half types dispatch (Half + BFloat16)
#define VLLM_STABLE_DISPATCH_CASE_HALF_TYPES(...) \
THO_DISPATCH_CASE(torch::headeronly::ScalarType::Half, __VA_ARGS__) \
THO_DISPATCH_CASE(torch::headeronly::ScalarType::BFloat16, __VA_ARGS__)
#define VLLM_STABLE_DISPATCH_HALF_TYPES(TYPE, NAME, ...) \
THO_DISPATCH_SWITCH(TYPE, NAME, \
VLLM_STABLE_DISPATCH_CASE_HALF_TYPES(__VA_ARGS__))
// Boolean dispatch
#define VLLM_STABLE_DISPATCH_BOOL(expr, const_expr, ...) \
if (expr) { \

View File

@@ -27,111 +27,4 @@ void per_token_group_quant_int8(const torch::stable::Tensor& input,
torch::stable::Tensor& output_s,
int64_t group_size, double eps, double int8_min,
double int8_max);
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability);
bool cutlass_group_gemm_supported(int64_t cuda_device_capability);
void cutlass_scaled_mm(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_moe_mm(torch::stable::Tensor& out_tensors,
torch::stable::Tensor const& a_tensors,
torch::stable::Tensor const& b_tensors,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& expert_offsets,
torch::stable::Tensor const& problem_sizes,
torch::stable::Tensor const& a_strides,
torch::stable::Tensor const& b_strides,
torch::stable::Tensor const& c_strides, bool per_act_token,
bool per_out_ch);
void cutlass_scaled_mm_azp(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias);
void get_cutlass_moe_mm_data(
const torch::stable::Tensor& topk_ids,
torch::stable::Tensor& expert_offsets,
torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2,
torch::stable::Tensor& input_permutation,
torch::stable::Tensor& output_permutation, const int64_t num_experts,
const int64_t n, const int64_t k,
const std::optional<torch::stable::Tensor>& blockscale_offsets,
const bool is_gated);
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
const torch::stable::Tensor& expert_first_token_offset,
torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2, const int64_t n, const int64_t k,
const bool swap_ab);
void get_cutlass_batched_moe_mm_data(
torch::stable::Tensor& expert_offsets,
torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2,
const torch::stable::Tensor& expert_num_tokens,
const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
const int64_t k);
// FP4/NVFP4 ops
bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability);
void cutlass_scaled_fp4_mm(torch::stable::Tensor& D,
torch::stable::Tensor const& A,
torch::stable::Tensor const& B,
torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha);
void cutlass_fp4_group_mm(torch::stable::Tensor& output,
const torch::stable::Tensor& a,
const torch::stable::Tensor& b,
const torch::stable::Tensor& a_blockscale,
const torch::stable::Tensor& b_blockscales,
const torch::stable::Tensor& alphas,
const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& sf_offsets);
std::tuple<torch::stable::Tensor, torch::stable::Tensor> scaled_fp4_quant_func(
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_scale, bool is_sf_swizzled_layout);
void scaled_fp4_quant_out(torch::stable::Tensor const& input,
torch::stable::Tensor const& input_scale,
bool is_sf_swizzled_layout,
torch::stable::Tensor& output,
torch::stable::Tensor& output_scale);
void scaled_fp4_experts_quant(
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_global_scale,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts);
void silu_and_mul_scaled_fp4_experts_quant(
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_global_scale,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts);
void silu_and_mul_nvfp4_quant(torch::stable::Tensor& out,
torch::stable::Tensor& output_block_scale,
torch::stable::Tensor& input,
torch::stable::Tensor& input_global_scale);
#endif

View File

@@ -1,175 +0,0 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "cutlass_extensions/common.hpp"
#include "nvfp4_utils.cuh"
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
void scaled_fp4_quant_sm1xxa(torch::stable::Tensor const& output,
torch::stable::Tensor const& input,
torch::stable::Tensor const& output_sf,
torch::stable::Tensor const& input_sf,
bool is_sf_swizzled_layout);
#endif
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
void scaled_fp4_experts_quant_sm1xxa(
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_global_scale,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts);
#endif
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
void silu_and_mul_nvfp4_quant_sm1xxa(torch::stable::Tensor& output,
torch::stable::Tensor& output_sf,
torch::stable::Tensor& input,
torch::stable::Tensor& input_sf);
#endif
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_global_scale,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts);
#endif
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
static bool nvfp4_quant_sm_supported() {
const int32_t sm = get_sm_version_num();
#if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100
if (sm >= 100 && sm < 120) return true;
#endif
#if defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120
if (sm >= 120 && sm < 130) return true;
#endif
return false;
}
#endif
void scaled_fp4_quant_out(torch::stable::Tensor const& input,
torch::stable::Tensor const& input_sf,
bool is_sf_swizzled_layout,
torch::stable::Tensor& output,
torch::stable::Tensor& output_sf) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
STD_TORCH_CHECK(nvfp4_quant_sm_supported(),
"No compiled nvfp4 quantization kernel for SM ",
get_sm_version_num(),
". Recompile with the appropriate CUDA arch.");
return scaled_fp4_quant_sm1xxa(output, input, output_sf, input_sf,
is_sf_swizzled_layout);
#endif
STD_TORCH_CHECK_NOT_IMPLEMENTED(false,
"No compiled nvfp4 quantization kernel");
}
std::tuple<torch::stable::Tensor, torch::stable::Tensor> scaled_fp4_quant_func(
torch::stable::Tensor const& input, torch::stable::Tensor const& input_sf,
bool is_sf_swizzled_layout) {
int64_t n = input.size(-1);
int64_t m = input.numel() / n;
auto device = input.device();
// Two fp4 values packed into a uint8
auto output = torch::stable::empty(
{m, n / 2}, torch::headeronly::ScalarType::Byte, std::nullopt, device);
torch::stable::Tensor output_sf;
if (is_sf_swizzled_layout) {
auto [sf_m, sf_n] = vllm::computeSwizzledSFShape(m, n);
output_sf = torch::stable::empty(
{sf_m, sf_n}, torch::headeronly::ScalarType::Int, std::nullopt, device);
} else {
output_sf = torch::stable::empty({m, n / CVT_FP4_SF_VEC_SIZE},
torch::headeronly::ScalarType::Byte,
std::nullopt, device);
}
scaled_fp4_quant_out(input, input_sf, is_sf_swizzled_layout, output,
output_sf);
return {output, output_sf};
}
void scaled_fp4_experts_quant(
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_global_scale,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
STD_TORCH_CHECK(nvfp4_quant_sm_supported(),
"No compiled nvfp4 experts quantization kernel for SM ",
get_sm_version_num(),
". Recompile with the appropriate CUDA arch.");
return scaled_fp4_experts_quant_sm1xxa(
output, output_scale, input, input_global_scale, input_offset_by_experts,
output_scale_offset_by_experts);
#endif
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false, "No compiled nvfp4 experts quantization kernel");
}
void silu_and_mul_nvfp4_quant(torch::stable::Tensor& output,
torch::stable::Tensor& output_sf,
torch::stable::Tensor& input,
torch::stable::Tensor& input_sf) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
STD_TORCH_CHECK(nvfp4_quant_sm_supported(),
"No compiled silu_and_mul nvfp4 quantization kernel for SM ",
get_sm_version_num(),
". Recompile with the appropriate CUDA arch.");
return silu_and_mul_nvfp4_quant_sm1xxa(output, output_sf, input, input_sf);
#endif
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false, "No compiled silu_and_mul nvfp4 quantization kernel");
}
void silu_and_mul_scaled_fp4_experts_quant(
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_global_scale,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
STD_TORCH_CHECK(nvfp4_quant_sm_supported(),
"No compiled silu_and_mul nvfp4 experts quantization kernel "
"for SM ",
get_sm_version_num(),
". Recompile with the appropriate CUDA arch.");
return silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
output, output_scale, input, input_global_scale, input_offset_by_experts,
output_scale_offset_by_experts);
#endif
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false, "No compiled silu_and_mul nvfp4 experts quantization kernel");
}

View File

@@ -1,87 +0,0 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "cutlass_extensions/common.hpp"
#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
void cutlass_scaled_fp4_mm_sm100a(torch::stable::Tensor& D,
torch::stable::Tensor const& A,
torch::stable::Tensor const& B,
torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha);
#endif
#if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
void cutlass_scaled_fp4_mm_sm120a(torch::stable::Tensor& D,
torch::stable::Tensor const& A,
torch::stable::Tensor const& B,
torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha);
#endif
void cutlass_scaled_fp4_mm(torch::stable::Tensor& D,
const torch::stable::Tensor& A,
const torch::stable::Tensor& B,
const torch::stable::Tensor& A_sf,
const torch::stable::Tensor& B_sf,
const torch::stable::Tensor& alpha) {
// Make sure we're on A's device.
const torch::stable::accelerator::DeviceGuard device_guard(
A.get_device_index());
const int32_t sm = get_sm_version_num();
#if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100
if (sm >= 100 && sm < 120) {
cutlass_scaled_fp4_mm_sm100a(D, A, B, A_sf, B_sf, alpha);
return;
}
#endif
#if defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120
if (sm >= 120 && sm < 130) {
cutlass_scaled_fp4_mm_sm120a(D, A, B, A_sf, B_sf, alpha);
return;
}
#endif
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false, "No compiled nvfp4 mm kernel for SM ", sm,
". Recompile with CUDA >= 12.8 and CC >= 100.");
}
bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability) {
int runtimeVersion;
cudaRuntimeGetVersion(&runtimeVersion);
if (runtimeVersion < 12080) return false;
// Only report support when the SM-specific kernel was actually compiled in,
// so the Python-side backend selector does not choose CUTLASS and then hit
// TORCH_CHECK_NOT_IMPLEMENTED (or worse, fall through to Marlin).
#if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100
if (cuda_device_capability >= 100 && cuda_device_capability < 120)
return true;
#endif
#if defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120
if (cuda_device_capability >= 120 && cuda_device_capability < 130)
return true;
#endif
return false;
}

View File

@@ -1,22 +0,0 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_blockwise_sm100_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace vllm {
void cutlass_scaled_mm_blockwise_sm100_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales) {
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::bfloat16_t>(
out, a, b, a_scales, b_scales);
} else {
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::half_t>(
out, a, b, a_scales, b_scales);
}
}
} // namespace vllm

View File

@@ -1,22 +0,0 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_blockwise_sm120_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace vllm {
void cutlass_scaled_mm_blockwise_sm120_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales) {
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
cutlass_gemm_blockwise_sm120_fp8_dispatch<cutlass::bfloat16_t>(
out, a, b, a_scales, b_scales);
} else {
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
cutlass_gemm_blockwise_sm120_fp8_dispatch<cutlass::half_t>(
out, a, b, a_scales, b_scales);
}
}
} // namespace vllm

View File

@@ -1,23 +0,0 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_blockwise_sm90_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace vllm {
void cutlass_scaled_mm_blockwise_sm90_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales) {
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::bfloat16_t>(
out, a, b, a_scales, b_scales);
} else {
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::half_t>(
out, a, b, a_scales, b_scales);
}
}
} // namespace vllm

View File

@@ -1,52 +0,0 @@
#pragma once
#include <torch/csrc/stable/tensor.h>
namespace vllm {
void cutlass_scaled_mm_sm90_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_scaled_mm_sm90_int8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_scaled_mm_azp_sm90_int8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_scaled_mm_blockwise_sm90_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales);
void cutlass_scaled_mm_sm100_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_scaled_mm_sm120_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_scaled_mm_blockwise_sm100_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales);
void cutlass_scaled_mm_blockwise_sm120_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales);
} // namespace vllm

View File

@@ -1,24 +0,0 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm100_fp8_dispatch.cuh"
namespace vllm {
void cutlass_scaled_mm_sm100_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias) {
STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (bias) {
STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
"currently bias dtype must match output dtype ",
out.scalar_type());
return cutlass_scaled_mm_sm100_fp8_epilogue<true>(out, a, b, a_scales,
b_scales, *bias);
} else {
return cutlass_scaled_mm_sm100_fp8_epilogue<false>(out, a, b, a_scales,
b_scales);
}
}
} // namespace vllm

View File

@@ -1,25 +0,0 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm120_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace vllm {
void cutlass_scaled_mm_sm120_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias) {
STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (bias) {
STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
"currently bias dtype must match output dtype ",
out.scalar_type());
return cutlass_scaled_mm_sm120_fp8_epilogue<c3x::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias);
} else {
return cutlass_scaled_mm_sm120_fp8_epilogue<c3x::ScaledEpilogue>(
out, a, b, a_scales, b_scales);
}
}
} // namespace vllm

View File

@@ -1,24 +0,0 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm90_fp8_dispatch.cuh"
namespace vllm {
void cutlass_scaled_mm_sm90_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias) {
STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (bias) {
STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
"currently bias dtype must match output dtype ",
out.scalar_type());
return cutlass_scaled_mm_sm90_fp8_epilogue<true>(out, a, b, a_scales,
b_scales, *bias);
} else {
return cutlass_scaled_mm_sm90_fp8_epilogue<false>(out, a, b, a_scales,
b_scales);
}
}
} // namespace vllm

View File

@@ -1,25 +0,0 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm90_int8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace vllm {
void cutlass_scaled_mm_sm90_int8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias) {
STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (bias) {
STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
"currently bias dtype must match output dtype ",
out.scalar_type());
return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias);
} else {
return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogue>(
out, a, b, a_scales, b_scales);
}
}
} // namespace vllm

View File

@@ -1,220 +0,0 @@
#include <stddef.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/core/ScalarType.h>
#include "cutlass/cutlass.h"
#include "scaled_mm_c2x.cuh"
#include "scaled_mm_c2x_sm75_dispatch.cuh"
#include "scaled_mm_c2x_sm80_dispatch.cuh"
#include "scaled_mm_c2x_sm89_fp8_dispatch.cuh"
#include "scaled_mm_c2x_sm89_int8_dispatch.cuh"
#include "libtorch_stable/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp"
using namespace vllm;
/*
This file defines quantized GEMM operations using the CUTLASS 2.x API, for
NVIDIA GPUs with SM versions prior to sm90 (Hopper).
*/
template <template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_scaled_mm_sm75_epilogue(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
EpilogueArgs&&... epilogue_args) {
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
return cutlass_gemm_sm75_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
return cutlass_gemm_sm75_dispatch<int8_t, cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
}
void cutlass_scaled_mm_sm75(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias) {
STD_TORCH_CHECK(a_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
if (bias) {
STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
"currently bias dtype must match output dtype ",
out.scalar_type());
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias);
} else {
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogue>(
out, a, b, a_scales, b_scales);
}
}
void cutlass_scaled_mm_azp_sm75(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias) {
STD_TORCH_CHECK(a_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
if (azp) {
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
} else {
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBiasAzp>(
out, a, b, a_scales, b_scales, azp_adj, bias);
}
}
template <template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_scaled_mm_sm80_epilogue(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
EpilogueArgs&&... epilogue_args) {
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
return cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
return cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
}
void cutlass_scaled_mm_sm80(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias) {
STD_TORCH_CHECK(a_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
if (bias) {
STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
"currently bias dtype must match output dtype ",
out.scalar_type());
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias);
} else {
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogue>(
out, a, b, a_scales, b_scales);
}
}
void cutlass_scaled_mm_azp_sm80(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias) {
STD_TORCH_CHECK(a_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
if (azp) {
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
} else {
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBiasAzp>(
out, a, b, a_scales, b_scales, azp_adj, bias);
}
}
template <template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_scaled_mm_sm89_epilogue(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
EpilogueArgs&&... epilogue_args) {
if (a.scalar_type() == torch::headeronly::ScalarType::Char) {
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::bfloat16_t,
Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
assert(out.scalar_type() == torch::headeronly::ScalarType::Half);
return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
} else {
STD_TORCH_CHECK(a.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
STD_TORCH_CHECK(b.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::bfloat16_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
}
}
void cutlass_scaled_mm_sm89(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias) {
STD_TORCH_CHECK(a_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
if (bias) {
STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
"currently bias dtype must match output dtype ",
out.scalar_type());
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias);
} else {
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogue>(
out, a, b, a_scales, b_scales);
}
}
void cutlass_scaled_mm_azp_sm89(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias) {
STD_TORCH_CHECK(a_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
if (azp) {
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
} else {
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBiasAzp>(
out, a, b, a_scales, b_scales, azp_adj, bias);
}
}

View File

@@ -1,38 +0,0 @@
#include "c3x/scaled_mm_helper.hpp"
#include "c3x/scaled_mm_kernels.hpp"
/*
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
NVIDIA GPUs with sm90a (Hopper).
*/
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
void cutlass_scaled_mm_sm90(torch::stable::Tensor& c,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias) {
dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
vllm::cutlass_scaled_mm_sm90_fp8,
vllm::cutlass_scaled_mm_sm90_int8,
vllm::cutlass_scaled_mm_blockwise_sm90_fp8);
}
void cutlass_scaled_mm_azp_sm90(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias) {
STD_TORCH_CHECK(a_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
vllm::cutlass_scaled_mm_azp_sm90_int8(out, a, b, a_scales, b_scales, azp_adj,
azp, bias);
}
#endif

View File

@@ -1,451 +0,0 @@
#include <cudaTypedefs.h>
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "cutlass_extensions/common.hpp"
void cutlass_scaled_mm_sm75(torch::stable::Tensor& c,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_scaled_mm_sm80(torch::stable::Tensor& c,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_scaled_mm_sm89(torch::stable::Tensor& c,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
void cutlass_scaled_mm_sm90(torch::stable::Tensor& c,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
#endif
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
void cutlass_moe_mm_sm90(torch::stable::Tensor& out_tensors,
torch::stable::Tensor const& a_tensors,
torch::stable::Tensor const& b_tensors,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& expert_offsets,
torch::stable::Tensor const& problem_sizes,
torch::stable::Tensor const& a_strides,
torch::stable::Tensor const& b_strides,
torch::stable::Tensor const& c_strides,
bool per_act_token, bool per_out_ch);
#endif
#if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100
void cutlass_moe_mm_sm100(torch::stable::Tensor& out_tensors,
torch::stable::Tensor const& a_tensors,
torch::stable::Tensor const& b_tensors,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& expert_offsets,
torch::stable::Tensor const& problem_sizes,
torch::stable::Tensor const& a_strides,
torch::stable::Tensor const& b_strides,
torch::stable::Tensor const& c_strides,
bool per_act_token, bool per_out_ch);
#endif
#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
void cutlass_scaled_mm_sm120(torch::stable::Tensor& c,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
#endif
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
void cutlass_scaled_mm_sm100(torch::stable::Tensor& c,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
#endif
#if (defined(ENABLE_CUTLASS_MOE_SM90) && ENABLE_CUTLASS_MOE_SM90) || \
(defined(ENABLE_CUTLASS_MOE_SM100) && ENABLE_CUTLASS_MOE_SM100) || \
(defined(ENABLE_CUTLASS_MOE_SM120) && ENABLE_CUTLASS_MOE_SM120)
void get_cutlass_moe_mm_data_caller(
const torch::stable::Tensor& topk_ids,
torch::stable::Tensor& expert_offsets,
torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2,
torch::stable::Tensor& input_permutation,
torch::stable::Tensor& output_permutation, const int64_t num_experts,
const int64_t n, const int64_t k,
const std::optional<torch::stable::Tensor>& blockscale_offsets,
const bool is_gated);
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
const torch::stable::Tensor& expert_first_token_offset,
torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2, const int64_t n, const int64_t k,
const bool swap_ab);
void get_cutlass_batched_moe_mm_data_caller(
torch::stable::Tensor& expert_offsets,
torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2,
const torch::stable::Tensor& expert_num_tokens,
const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
const int64_t k);
#endif
void cutlass_scaled_mm_azp_sm75(
torch::stable::Tensor& c, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_scaled_mm_azp_sm80(
torch::stable::Tensor& c, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_scaled_mm_azp_sm89(
torch::stable::Tensor& c, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias);
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
void cutlass_scaled_mm_azp_sm90(
torch::stable::Tensor& c, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias);
#endif
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
// CUTLASS FP8 kernels need at least
// CUDA 12.0 on SM90 systems (Hopper)
// CUDA 12.4 on SM89 systems (Lovelace)
#if defined CUDA_VERSION
if (cuda_device_capability >= 90) {
return CUDA_VERSION >= 12000;
} else if (cuda_device_capability >= 89) {
return CUDA_VERSION >= 12040;
}
#endif
return false;
}
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
// CUTLASS block-quantized FP8 kernels need at least CUDA 12.0
// and at least SM90 (Hopper)
#if defined CUDA_VERSION
if (cuda_device_capability >= 100) {
return CUDA_VERSION >= 12080;
} else if (cuda_device_capability >= 90) {
return CUDA_VERSION >= 12000;
}
#endif
return false;
}
bool cutlass_group_gemm_supported(int64_t cuda_device_capability) {
// CUTLASS grouped FP8 kernels need at least CUDA 12.3 and SM90 (Hopper)
// or CUDA 12.8 and SM100 (Blackwell)
#if defined CUDA_VERSION
if (cuda_device_capability >= 100) {
return CUDA_VERSION >= 12080;
}
if (cuda_device_capability >= 90) {
return CUDA_VERSION >= 12030;
}
#endif
return false;
}
void cutlass_scaled_mm(torch::stable::Tensor& c, torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias) {
// Checks for conformality
STD_TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
STD_TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
b.size(1) == c.size(1));
// Check for strides and alignment
STD_TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major
STD_TORCH_CHECK(b.stride(0) == 1); // Column-major
STD_TORCH_CHECK(c.stride(0) % 16 == 0 &&
b.stride(1) % 16 == 0); // 16 Byte Alignment
if (bias) {
STD_TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
bias->dim() == 1);
}
const torch::stable::accelerator::DeviceGuard device_guard(
a.get_device_index());
int32_t version_num = get_sm_version_num();
#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
if (version_num >= 120) {
cutlass_scaled_mm_sm120(c, a, b, a_scales, b_scales, bias);
return;
}
#endif
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
if (version_num >= 100 && version_num < 120) {
cutlass_scaled_mm_sm100(c, a, b, a_scales, b_scales, bias);
return;
}
#endif
// Guard against compilation issues for sm90 kernels
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
if (version_num >= 90 && version_num < 100) {
// Hopper
cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
return;
}
#endif
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
if (version_num == 89) {
// Ada Lovelace
cutlass_scaled_mm_sm89(c, a, b, a_scales, b_scales, bias);
return;
}
if (version_num >= 80) {
// Ampere
cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias);
return;
}
if (version_num >= 75) {
// Turing
cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias);
return;
}
#endif
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_scaled_mm for a compute capability less than "
"CUDA device capability: ",
version_num);
}
void cutlass_moe_mm(torch::stable::Tensor& out_tensors,
torch::stable::Tensor const& a_tensors,
torch::stable::Tensor const& b_tensors,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& expert_offsets,
torch::stable::Tensor const& problem_sizes,
torch::stable::Tensor const& a_strides,
torch::stable::Tensor const& b_strides,
torch::stable::Tensor const& c_strides, bool per_act_token,
bool per_out_ch) {
int32_t version_num = get_sm_version_num();
#if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100
if (version_num >= 100 && version_num < 110) {
cutlass_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
expert_offsets, problem_sizes, a_strides, b_strides,
c_strides, per_act_token, per_out_ch);
return;
}
#endif
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
if (version_num >= 90 && version_num < 100) {
cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
expert_offsets, problem_sizes, a_strides, b_strides,
c_strides, per_act_token, per_out_ch);
return;
}
#endif
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_scaled_mm for CUDA device capability: ", version_num,
". Required capability: 90 or 100");
}
void get_cutlass_moe_mm_data(
const torch::stable::Tensor& topk_ids,
torch::stable::Tensor& expert_offsets,
torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2,
torch::stable::Tensor& input_permutation,
torch::stable::Tensor& output_permutation, const int64_t num_experts,
const int64_t n, const int64_t k,
const std::optional<torch::stable::Tensor>& blockscale_offsets,
const bool is_gated) {
// This function currently gets compiled only if we have a valid cutlass moe
// mm to run it for.
int32_t version_num = get_sm_version_num();
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
(defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1,
problem_sizes2, input_permutation,
output_permutation, num_experts, n, k,
blockscale_offsets, is_gated);
return;
#endif
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled get_cutlass_moe_mm_data: no cutlass_scaled_mm kernel for "
"CUDA device capability: ",
version_num, ". Required capability: 90, 100, or 120");
}
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
const torch::stable::Tensor& expert_first_token_offset,
torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2, const int64_t n, const int64_t k,
const bool swap_ab) {
int32_t version_num = get_sm_version_num();
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
(defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
expert_first_token_offset, problem_sizes1, problem_sizes2, n, k, swap_ab);
return;
#endif
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled get_cutlass_moe_mm_problem_sizes_from_expert_offsets: "
"no cutlass_scaled_mm kernel for CUDA device capability: ",
version_num, ". Required capability: 90, 100, or 120");
}
void get_cutlass_batched_moe_mm_data(
torch::stable::Tensor& expert_offsets,
torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2,
const torch::stable::Tensor& expert_num_tokens,
const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
const int64_t k) {
// This function currently gets compiled only if we have a valid cutlass moe
// mm to run it for.
int32_t version_num = get_sm_version_num();
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
(defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
get_cutlass_batched_moe_mm_data_caller(expert_offsets, problem_sizes1,
problem_sizes2, expert_num_tokens,
num_local_experts, padded_m, n, k);
return;
#endif
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled get_cutlass_batched_moe_mm_data: no "
"cutlass_scaled_mm kernel "
"for CUDA device capability: ",
version_num, ". Required capability: 90, 100, or 120");
}
void cutlass_scaled_mm_azp(torch::stable::Tensor& c,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias) {
// Checks for conformality
STD_TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
STD_TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
b.size(1) == c.size(1));
STD_TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
STD_TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));
// Check for strides and alignment
STD_TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major
STD_TORCH_CHECK(b.stride(0) == 1); // Column-major
STD_TORCH_CHECK(c.stride(0) % 16 == 0 &&
b.stride(1) % 16 == 0); // 16 Byte Alignment
STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
// bias, azp, azp_adj are all 1d
// bias and azp_adj have n elements, azp has m elements
if (bias) {
STD_TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous());
}
if (azp) {
STD_TORCH_CHECK(azp->numel() == a.size(0) && azp->is_contiguous());
}
STD_TORCH_CHECK(azp_adj.numel() == b.size(1) && azp_adj.is_contiguous());
// azp & bias types
STD_TORCH_CHECK(azp_adj.scalar_type() == torch::headeronly::ScalarType::Int);
STD_TORCH_CHECK(!azp ||
azp->scalar_type() == torch::headeronly::ScalarType::Int);
STD_TORCH_CHECK(!bias || bias->scalar_type() == c.scalar_type(),
"currently bias dtype must match output dtype ",
c.scalar_type());
const torch::stable::accelerator::DeviceGuard device_guard(
a.get_device_index());
int32_t version_num = get_sm_version_num();
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
if (version_num >= 90) {
cutlass_scaled_mm_azp_sm90(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
return;
}
#endif
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
if (version_num == 89) {
// Ada Lovelace
cutlass_scaled_mm_azp_sm89(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
return;
}
if (version_num >= 80) {
// Ampere
cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
return;
}
// Turing
STD_TORCH_CHECK(version_num >= 75);
cutlass_scaled_mm_azp_sm75(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
return;
#endif
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_scaled_mm_azp for a compute capability less than "
"CUDA device capability: ",
version_num);
}

View File

@@ -31,174 +31,6 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
"per_token_group_quant_int8(Tensor input, Tensor! output_q, Tensor! "
"output_s, int group_size, float eps, float int8_min, float int8_max) -> "
"()");
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
// quantization, as well as bias
ops.def(
"cutlass_scaled_mm(Tensor! out, Tensor a,"
" Tensor b, Tensor a_scales,"
" Tensor b_scales, Tensor? bias) -> ()");
// CUTLASS w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
// quantization.
ops.def(
"cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
" Tensor b, Tensor a_scales,"
" Tensor b_scales, Tensor azp_adj,"
" Tensor? azp, Tensor? bias) -> ()");
// Check if cutlass scaled_mm is supported for CUDA devices of the given
// capability
ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool");
// Check if cutlass grouped gemm is supported for CUDA devices of the given
// capability
ops.def("cutlass_group_gemm_supported(int cuda_device_capability) -> bool");
// CUTLASS w8a8 grouped GEMM
ops.def(
"cutlass_moe_mm(Tensor! out_tensors, Tensor a_tensors, Tensor b_tensors, "
" Tensor a_scales, Tensor b_scales, Tensor expert_offsets, "
" Tensor problem_sizes, Tensor a_strides, "
" Tensor b_strides, Tensor c_strides, bool per_act_token, "
" bool per_out_ch) -> ()");
// A function that computes data required to run fused MoE with w8a8 grouped
// GEMM. It takes topk_ids as an input, and computes expert_offsets
// (token start indices of each expert). In addition to this, it computes
// problem sizes for each expert's multiplication used by the two mms called
// from fused MoE operation, and arrays with permutations required to shuffle
// and de-shuffle the input/output of the fused operation.
ops.def(
"get_cutlass_moe_mm_data(Tensor topk_ids, Tensor! expert_offsets, "
" Tensor! problem_sizes1, Tensor! problem_sizes2, "
" Tensor! input_permutation, "
" Tensor! output_permutation, int num_experts, "
" int n, int k, Tensor? blockscale_offsets, "
" bool is_gated) -> ()");
// compute per-expert problem sizes from expert_first_token_offset
// produced by vLLM's moe_permute kernel
ops.def(
"get_cutlass_moe_mm_problem_sizes_from_expert_offsets("
" Tensor expert_first_token_offset, "
" Tensor! problem_sizes1, "
" Tensor! problem_sizes2, "
" int n, int k, bool swap_ab) -> ()");
// A function that computes data required to run fused MoE with w8a8 grouped
// GEMM in batched expert format. It takes expert_num_tokens
// as an input, and computes expert_offsets (token start indices of each
// expert). In addition to this, it computes problem sizes for each expert's
// multiplication used by the two mms called from fused MoE operation.
ops.def(
"get_cutlass_batched_moe_mm_data(Tensor! expert_offsets, "
" Tensor! problem_sizes1, "
" Tensor! problem_sizes2, "
" Tensor expert_num_tokens, "
" int num_local_experts, int padded_m, "
" int n, int k) -> ()");
// Check if cutlass scaled_mm supports block quantization (used by DeepSeekV3)
ops.def(
"cutlass_scaled_mm_supports_block_fp8(int cuda_device_capability) -> "
"bool");
// CUTLASS nvfp4 block scaled GEMM
ops.def(
"cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
" Tensor block_scale_a, Tensor block_scale_b,"
" Tensor alpha) -> ()");
// cutlass nvfp4 block scaled group GEMM
ops.def(
"cutlass_fp4_group_mm(Tensor! out, Tensor a, Tensor b,"
" Tensor a_blockscale, Tensor b_blockscales, Tensor alphas,"
" Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()");
// Compute NVFP4 block quantized tensor.
ops.def(
"scaled_fp4_quant(Tensor input,"
" Tensor input_scale, bool "
"is_sf_swizzled_layout) -> (Tensor, Tensor)");
// Out variant
// TODO: Add out_variant tag once PyTorch supports it (added in 2.11)
// This registration is now migrated to stable ABI
// at::Tag::out_variant is not available in the stable ABI (enum_tag.h is not
// yet in torch/headeronly), the tag should be applied from Python
// via torch.library.Library.define(..., tags=(torch.Tag.out_variant,))
// with the .impl remaining in C++.
// See pytorch/pytorch#176117.
ops.def(
"scaled_fp4_quant.out(Tensor input,"
" Tensor input_scale, bool "
"is_sf_swizzled_layout, *, Tensor(a!) output, Tensor(b!) output_scale) "
"-> ()");
// Compute NVFP4 experts quantization.
ops.def(
"scaled_fp4_experts_quant(Tensor! output, Tensor! output_scale,"
"Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
"Tensor output_scale_offset_by_experts) -> ()");
// Fused SiLU+Mul+NVFP4 experts quantization.
ops.def(
"silu_and_mul_scaled_fp4_experts_quant(Tensor! output, Tensor! "
"output_scale,"
"Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
"Tensor output_scale_offset_by_experts) -> ()");
// Fused SiLU+Mul+NVFP4 quantization.
ops.def(
"silu_and_mul_nvfp4_quant(Tensor! result, Tensor! result_block_scale, "
"Tensor input, Tensor input_global_scale) -> ()");
// Check if cutlass_scaled_mm_fp4 is supported for CUDA devices
// of the given capability
ops.def("cutlass_scaled_mm_supports_fp4(int cuda_device_capability) -> bool");
// CUTLASS w4a8 GEMM
ops.def(
"cutlass_w4a8_mm("
" Tensor A,"
" Tensor B,"
" Tensor group_scales,"
" int group_size,"
" Tensor channel_scales,"
" Tensor token_scales,"
" ScalarType? out_type,"
" str? maybe_schedule"
") -> Tensor");
// pack scales
ops.def("cutlass_pack_scale_fp8(Tensor scales) -> Tensor");
// encode and reorder weight matrix
ops.def("cutlass_encode_and_reorder_int4b(Tensor B) -> Tensor");
// CUTLASS w4a8 grouped GEMM
ops.def(
"cutlass_w4a8_moe_mm("
" Tensor! out_tensors,"
" Tensor a_tensors,"
" Tensor b_tensors,"
" Tensor a_scales,"
" Tensor b_scales,"
" Tensor b_group_scales,"
" int b_group_size,"
" Tensor expert_offsets,"
" Tensor problem_sizes,"
" Tensor a_strides,"
" Tensor b_strides,"
" Tensor c_strides,"
" Tensor group_scale_strides,"
" str? maybe_schedule"
") -> ()");
ops.def(
"cutlass_encode_and_reorder_int4b_grouped(Tensor b_tensors) -> (Tensor, "
"Tensor)");
#endif
}
@@ -214,45 +46,6 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
TORCH_BOX(&per_token_group_quant_8bit_packed));
ops.impl("per_token_group_quant_int8",
TORCH_BOX(&per_token_group_quant_int8));
// CUTLASS scaled_mm ops
ops.impl("cutlass_scaled_mm", TORCH_BOX(&cutlass_scaled_mm));
ops.impl("cutlass_scaled_mm_azp", TORCH_BOX(&cutlass_scaled_mm_azp));
ops.impl("cutlass_moe_mm", TORCH_BOX(&cutlass_moe_mm));
ops.impl("get_cutlass_moe_mm_data", TORCH_BOX(&get_cutlass_moe_mm_data));
ops.impl("get_cutlass_moe_mm_problem_sizes_from_expert_offsets",
TORCH_BOX(&get_cutlass_moe_mm_problem_sizes_from_expert_offsets));
ops.impl("get_cutlass_batched_moe_mm_data",
TORCH_BOX(&get_cutlass_batched_moe_mm_data));
// FP4/NVFP4 ops
ops.impl("cutlass_scaled_fp4_mm", TORCH_BOX(&cutlass_scaled_fp4_mm));
ops.impl("scaled_fp4_quant", TORCH_BOX(&scaled_fp4_quant_func));
ops.impl("scaled_fp4_quant.out", TORCH_BOX(&scaled_fp4_quant_out));
ops.impl("scaled_fp4_experts_quant", TORCH_BOX(&scaled_fp4_experts_quant));
ops.impl("silu_and_mul_scaled_fp4_experts_quant",
TORCH_BOX(&silu_and_mul_scaled_fp4_experts_quant));
ops.impl("silu_and_mul_nvfp4_quant", TORCH_BOX(&silu_and_mul_nvfp4_quant));
// W4A8 ops: impl registrations are in the source files
// (w4a8_mm_entry.cu and w4a8_grouped_mm_entry.cu)
#endif
}
// These capability-check functions take only primitive args (no tensors), so
// there is no device to dispatch on. CompositeExplicitAutograd makes them
// available for all backends. This is the stable ABI equivalent of calling
// ops.impl("op_name", &func) without a dispatch key in the non-stable API.
STABLE_TORCH_LIBRARY_IMPL(_C, CompositeExplicitAutograd, ops) {
#ifndef USE_ROCM
ops.impl("cutlass_scaled_mm_supports_fp8",
TORCH_BOX(&cutlass_scaled_mm_supports_fp8));
ops.impl("cutlass_group_gemm_supported",
TORCH_BOX(&cutlass_group_gemm_supported));
ops.impl("cutlass_scaled_mm_supports_block_fp8",
TORCH_BOX(&cutlass_scaled_mm_supports_block_fp8));
ops.impl("cutlass_scaled_mm_supports_fp4",
TORCH_BOX(&cutlass_scaled_mm_supports_fp4));
#endif
}

View File

@@ -1,17 +1,10 @@
#pragma once
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/stable/accelerator.h>
#include <torch/csrc/stable/ops.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/util/shim_utils.h>
#include <cuda_runtime.h>
// Stable ABI equivalent of TORCH_CHECK_NOT_IMPLEMENTED.
#define STD_TORCH_CHECK_NOT_IMPLEMENTED(cond, ...) \
STD_TORCH_CHECK(cond, "NotImplementedError: ", __VA_ARGS__)
// Utility to get the current CUDA stream for a given device using stable APIs.
// Returns a cudaStream_t for use in kernel launches.
inline cudaStream_t get_current_cuda_stream(int32_t device_index = -1) {

View File

@@ -21,7 +21,7 @@ struct SSMParamsBase {
int dim_ngroups_ratio;
bool is_variable_B;
bool is_variable_C;
int64_t null_block_id;
int64_t pad_slot_id;
bool delta_softplus;
bool cache_enabled;

View File

@@ -118,17 +118,9 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
const int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr
: reinterpret_cast<int *>(params.cache_indices_ptr);
int cache_index;
if (cache_indices == nullptr) {
cache_index = batch_id;
} else if (params.cache_enabled) {
const int* initial_state_idx = reinterpret_cast<const int*>(params.initial_state_idx_ptr);
cache_index = cache_indices[batch_id * params.cache_indices_stride + initial_state_idx[batch_id]];
} else {
cache_index = cache_indices[batch_id];
}
// Skip batch entries whose cache index maps to the null block (padding).
if (cache_indices != nullptr && cache_index == params.null_block_id){
const int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];
// cache_index == params.pad_slot_id is defined as padding, so we exit early
if (cache_index == params.pad_slot_id){
return;
}
input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + sequence_start_index * params.u_batch_stride
@@ -535,7 +527,7 @@ void set_ssm_params_fwd(SSMParamsBase &params,
const std::optional<at::Tensor>& cache_indices,
const std::optional<at::Tensor>& has_initial_state,
bool varlen,
int64_t null_block_id,
int64_t pad_slot_id,
int64_t block_size,
const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
@@ -552,7 +544,7 @@ void set_ssm_params_fwd(SSMParamsBase &params,
params.dstate = dstate;
params.n_groups = n_groups;
params.dim_ngroups_ratio = dim / n_groups;
params.null_block_id = null_block_id;
params.pad_slot_id = pad_slot_id;
params.delta_softplus = delta_softplus;
@@ -666,7 +658,7 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
const torch::Tensor &ssm_states,
// used to identify padding entries if cache_indices provided
// in case of padding, the kernel will return early
int64_t null_block_id,
int64_t pad_slot_id,
int64_t block_size,
const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
@@ -813,7 +805,7 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
cache_indices,
has_initial_state,
varlen,
null_block_id,
pad_slot_id,
block_size,
block_idx_first_scheduled_token,
block_idx_last_scheduled_token,

View File

@@ -0,0 +1,144 @@
/*
* Adapted from
* https://github.com/NVIDIA/TensorRT-LLM/blob/v1.3.0rc7/cpp/tensorrt_llm/kernels/tinygemm2/tinygemm2_cuda.cu
* Copyright (c) 2025, The vLLM team.
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
* All rights reserved. SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/all.h>
#include "gpt_oss_router_gemm.cuh"
void launch_gpt_oss_router_gemm(__nv_bfloat16* gA, __nv_bfloat16* gB,
__nv_bfloat16* gC, __nv_bfloat16* bias,
int batch_size, int output_features,
int input_features, cudaStream_t stream) {
static int const WARP_TILE_M = 16;
static int const TILE_M = WARP_TILE_M;
static int const TILE_N = 8;
static int const TILE_K = 64;
static int const STAGES = 16;
static int const STAGE_UNROLL = 4;
static bool const PROFILE = false;
CUtensorMap weight_map{};
CUtensorMap activation_map{};
constexpr uint32_t rank = 2;
uint64_t size[rank] = {(uint64_t)input_features, (uint64_t)output_features};
uint64_t stride[rank - 1] = {input_features * sizeof(__nv_bfloat16)};
uint32_t box_size[rank] = {TILE_K, TILE_M};
uint32_t elem_stride[rank] = {1, 1};
CUresult res = cuTensorMapEncodeTiled(
&weight_map, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, rank,
gB, size, stride, box_size, elem_stride,
CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B,
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE,
CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
TORCH_CHECK(res == CUDA_SUCCESS,
"cuTensorMapEncodeTiled failed for weight_map, error code=",
static_cast<int>(res));
size[1] = batch_size;
box_size[1] = TILE_N;
res = cuTensorMapEncodeTiled(
&activation_map, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
rank, gA, size, stride, box_size, elem_stride,
CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B,
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE,
CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
TORCH_CHECK(res == CUDA_SUCCESS,
"cuTensorMapEncodeTiled failed for activation_map, error code=",
static_cast<int>(res));
int smem_size = STAGES * STAGE_UNROLL *
(TILE_M * TILE_K * sizeof(__nv_bfloat16) +
TILE_N * TILE_K * sizeof(__nv_bfloat16));
gpuErrChk(cudaFuncSetAttribute(
gpt_oss_router_gemm_kernel<WARP_TILE_M, TILE_M, TILE_N, TILE_K, STAGES,
STAGE_UNROLL, PROFILE>,
cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
int tiles_m = (output_features + TILE_M - 1) / TILE_M;
int tiles_n = (batch_size + TILE_N - 1) / TILE_N;
dim3 grid(tiles_m, tiles_n);
dim3 block(384);
cudaLaunchConfig_t config;
cudaLaunchAttribute attrs[1];
config.gridDim = grid;
config.blockDim = block;
config.dynamicSmemBytes = smem_size;
config.stream = stream;
config.attrs = attrs;
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = 1;
config.numAttrs = 1;
cudaLaunchKernelEx(
&config,
&gpt_oss_router_gemm_kernel<WARP_TILE_M, TILE_M, TILE_N, TILE_K, STAGES,
STAGE_UNROLL, PROFILE>,
gC, gA, gB, bias, output_features, batch_size, input_features, weight_map,
activation_map, nullptr);
}
void gpt_oss_router_gemm_cuda_forward(torch::Tensor& output,
torch::Tensor input, torch::Tensor weight,
torch::Tensor bias) {
auto const batch_size = input.size(0);
auto const input_dim = input.size(1);
auto const output_dim = weight.size(0);
auto stream = at::cuda::getCurrentCUDAStream();
if (input.scalar_type() == at::ScalarType::BFloat16) {
launch_gpt_oss_router_gemm((__nv_bfloat16*)input.data_ptr(),
(__nv_bfloat16*)weight.data_ptr(),
(__nv_bfloat16*)output.mutable_data_ptr(),
(__nv_bfloat16*)bias.data_ptr(), batch_size,
output_dim, input_dim, stream);
} else {
throw std::invalid_argument("Unsupported dtype, only supports bfloat16");
}
}
void gpt_oss_router_gemm(torch::Tensor& output, torch::Tensor input,
torch::Tensor weight, torch::Tensor bias) {
TORCH_CHECK(input.dim() == 2, "input must be 2D");
TORCH_CHECK(weight.dim() == 2, "weight must be 2D");
TORCH_CHECK(bias.dim() == 1, "bias must be 1D");
TORCH_CHECK(input.sizes()[1] == weight.sizes()[1],
"input.size(1) must match weight.size(1)");
TORCH_CHECK(weight.sizes()[0] == bias.sizes()[0],
"weight.size(0) must match bias.size(0)");
TORCH_CHECK(input.scalar_type() == at::ScalarType::BFloat16,
"input tensor must be bfloat16");
TORCH_CHECK(weight.scalar_type() == at::ScalarType::BFloat16,
"weight tensor must be bfloat16");
TORCH_CHECK(bias.scalar_type() == at::ScalarType::BFloat16,
"bias tensor must be bfloat16");
gpt_oss_router_gemm_cuda_forward(output, input, weight, bias);
}

View File

@@ -0,0 +1,447 @@
/*
* Adapted from
* https://github.com/NVIDIA/TensorRT-LLM/blob/v1.3.0rc7/cpp/tensorrt_llm/kernels/tinygemm2/tinygemm2_kernel.cuh
* Copyright (c) 2025, The vLLM team.
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
* All rights reserved. SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "cuda_bf16.h"
#include <stdint.h>
#include <stdio.h>
#include <vector>
#include "cuda_pipeline.h"
#include <cuda.h>
#include <cuda/barrier>
#include <cuda/std/utility>
#include <cuda_runtime.h>
using barrier = cuda::barrier<cuda::thread_scope_block>;
namespace cde = cuda::device::experimental;
namespace ptx = cuda::ptx;
#define gpuErrChk(ans) \
{ \
gpuAssert((ans), __FILE__, __LINE__); \
}
inline void gpuAssert(cudaError_t code, char const* file, int line,
bool abort = true) {
if (code != cudaSuccess) {
fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file,
line);
if (abort) {
throw std::runtime_error(cudaGetErrorString(code));
}
}
}
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
__device__ uint64_t gclock64() {
unsigned long long int rv;
asm volatile("mov.u64 %0, %%globaltimer;" : "=l"(rv));
return rv;
}
__device__ void ldmatrix(__nv_bfloat16 rv[2], uint32_t smem_ptr) {
int dst;
asm volatile("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n"
: "=r"(dst)
: "r"(smem_ptr));
int* rvi = reinterpret_cast<int*>(&rv[0]);
rvi[0] = dst;
}
__device__ void ldmatrix2(__nv_bfloat16 rv[4], uint32_t smem_ptr) {
int x, y;
asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n"
: "=r"(x), "=r"(y)
: "r"(smem_ptr));
int* rvi = reinterpret_cast<int*>(&rv[0]);
rvi[0] = x;
rvi[1] = y;
}
__device__ void ldmatrix4(__nv_bfloat16 rv[8], uint32_t smem_ptr) {
int x, y, z, w;
asm volatile(
"ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];"
: "=r"(x), "=r"(y), "=r"(z), "=r"(w)
: "r"(smem_ptr));
int* rvi = reinterpret_cast<int*>(&rv[0]);
rvi[0] = x;
rvi[1] = y;
rvi[2] = z;
rvi[3] = w;
}
__device__ void HMMA_1688(float d[4], __nv_bfloat16 a[4], __nv_bfloat16 b[2],
float c[4]) {
uint32_t const* A = reinterpret_cast<uint32_t const*>(&a[0]);
uint32_t const* B = reinterpret_cast<uint32_t const*>(&b[0]);
float const* C = reinterpret_cast<float const*>(&c[0]);
float* D = reinterpret_cast<float*>(&d[0]);
asm volatile(
"mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
: "r"(A[0]), "r"(A[1]), "r"(B[0]), "f"(C[0]), "f"(C[1]), "f"(C[2]),
"f"(C[3]));
}
__device__ void HMMA_16816(float d[4], __nv_bfloat16 a[8], __nv_bfloat16 b[4],
float c[4]) {
uint32_t const* A = reinterpret_cast<uint32_t const*>(&a[0]);
uint32_t const* B = reinterpret_cast<uint32_t const*>(&b[0]);
float const* C = reinterpret_cast<float const*>(&c[0]);
float* D = reinterpret_cast<float*>(&d[0]);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
: "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
"f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
}
__device__ void bar_wait(uint32_t bar_ptr, int phase) {
asm volatile(
"{\n"
".reg .pred P1;\n"
"LAB_WAIT:\n"
"mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1;\n"
"@P1 bra.uni DONE;\n"
"bra.uni LAB_WAIT;\n"
"DONE:\n"
"}\n" ::"r"(bar_ptr),
"r"(phase));
}
__device__ bool bar_try_wait(uint32_t bar_ptr, int phase) {
uint32_t success;
#ifdef INTERNAL
asm volatile(".pragma \"set knob DontInsertYield\";\n" : : : "memory");
#endif
asm volatile(
"{\n\t"
".reg .pred P1; \n\t"
"mbarrier.try_wait.parity.shared::cta.b64 P1, [%1], %2; \n\t"
"selp.b32 %0, 1, 0, P1; \n\t"
"}"
: "=r"(success)
: "r"(bar_ptr), "r"(phase));
return success;
}
__device__ uint32_t elect_one_sync() {
uint32_t pred = 0;
uint32_t laneid = 0;
asm volatile(
"{\n"
".reg .b32 %%rx;\n"
".reg .pred %%px;\n"
" elect.sync %%rx|%%px, %2;\n"
"@%%px mov.s32 %1, 1;\n"
" mov.s32 %0, %%rx;\n"
"}\n"
: "+r"(laneid), "+r"(pred)
: "r"(0xFFFFFFFF));
return pred;
}
#endif
struct Profile {
uint64_t start;
uint64_t weight_load_start;
uint64_t act_load_start;
uint64_t compute_start;
uint64_t complete;
};
template <int WARP_TILE_M, int TILE_M, int TILE_N, int TILE_K, int STAGES,
int STAGE_UNROLL, bool PROFILE>
__global__ __launch_bounds__(384, 1) void gpt_oss_router_gemm_kernel(
__nv_bfloat16* output, __nv_bfloat16* weights, __nv_bfloat16* activations,
__nv_bfloat16* bias, int M, int N, int K,
const __grid_constant__ CUtensorMap weight_map,
const __grid_constant__ CUtensorMap activation_map,
Profile* profile = nullptr) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
if (PROFILE && threadIdx.x == 0 && blockIdx.y == 0)
profile[blockIdx.x].start = gclock64();
extern __shared__ __align__(128) char smem[];
__nv_bfloat16* sh_weights = (__nv_bfloat16*)&smem[0];
__nv_bfloat16* sh_activations =
(__nv_bfloat16*)&smem[STAGES * STAGE_UNROLL * TILE_M * TILE_K *
sizeof(__nv_bfloat16)];
#pragma nv_diag_suppress static_var_with_dynamic_init
__shared__ barrier bar_wt_ready[STAGES];
__shared__ barrier bar_act_ready[STAGES];
__shared__ barrier bar_data_consumed[STAGES];
__shared__ float4 reduction_buffer[128];
__shared__ nv_bfloat16 sh_bias[TILE_M];
if (threadIdx.x == 0) {
for (int i = 0; i < STAGES; i++) {
init(&bar_wt_ready[i], 1);
init(&bar_act_ready[i], 1);
init(&bar_data_consumed[i], 32);
}
ptx::fence_proxy_async(ptx::space_shared);
asm volatile("prefetch.tensormap [%0];"
:
: "l"(reinterpret_cast<uint64_t>(&weight_map))
: "memory");
asm volatile("prefetch.tensormap [%0];"
:
: "l"(reinterpret_cast<uint64_t>(&activation_map))
: "memory");
}
__syncthreads();
int warp_id = threadIdx.x / 32;
int lane_id = threadIdx.x % 32;
int phase = 0;
int mib = blockIdx.x * TILE_M;
int ni = blockIdx.y * TILE_N;
float accum[4];
for (int i = 0; i < 4; i++) accum[i] = 0.f;
int const K_LOOPS_DMA =
(K + 4 * TILE_K * STAGE_UNROLL - 1) / (4 * (TILE_K * STAGE_UNROLL));
int const K_LOOPS_COMPUTE = K_LOOPS_DMA;
// Data loading thread
if (warp_id >= 4 && elect_one_sync()) {
int stage = warp_id % 4;
bool weight_warp = warp_id < 8;
if (!weight_warp) {
cudaGridDependencySynchronize();
cudaTriggerProgrammaticLaunchCompletion();
}
for (int ki = 0; ki < K_LOOPS_DMA; ki++) {
int k = (ki * 4 + (warp_id % 4)) * TILE_K * STAGE_UNROLL;
uint64_t desc_ptr_wt = reinterpret_cast<uint64_t>(&weight_map);
uint64_t desc_ptr_act = reinterpret_cast<uint64_t>(&activation_map);
uint32_t bar_ptr_wt = __cvta_generic_to_shared(&bar_wt_ready[stage]);
uint32_t bar_ptr_act = __cvta_generic_to_shared(&bar_act_ready[stage]);
int bytes_wt = TILE_M * TILE_K * sizeof(__nv_bfloat16);
int bytes_act = TILE_N * TILE_K * sizeof(__nv_bfloat16);
bar_wait(__cvta_generic_to_shared(&bar_data_consumed[stage]), phase ^ 1);
if (weight_warp)
asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;"
:
: "r"(bar_ptr_wt), "r"(STAGE_UNROLL * bytes_wt));
if (!weight_warp)
asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;"
:
: "r"(bar_ptr_act), "r"(STAGE_UNROLL * bytes_act));
if (PROFILE && blockIdx.y == 0 && ki == 0 && weight_warp)
profile[blockIdx.x].weight_load_start = gclock64();
if (PROFILE && blockIdx.y == 0 && ki == 0 && !weight_warp)
profile[blockIdx.x].act_load_start = gclock64();
for (int i = 0; i < STAGE_UNROLL; i++) {
uint32_t smem_ptr_wt = __cvta_generic_to_shared(
&sh_weights[(stage * STAGE_UNROLL + i) * TILE_M * TILE_K]);
uint32_t crd0 = k + i * TILE_K;
uint32_t crd1 = mib;
if (weight_warp)
asm volatile(
"cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::complete_"
"tx::bytes [%0], [%1, {%3,%4}], "
"[%2];"
:
: "r"(smem_ptr_wt), "l"(desc_ptr_wt), "r"(bar_ptr_wt), "r"(crd0),
"r"(crd1)
: "memory");
uint32_t smem_ptr_act = __cvta_generic_to_shared(
&sh_activations[(stage * STAGE_UNROLL + i) * TILE_N * TILE_K]);
crd0 = k + i * TILE_K;
crd1 = ni;
if (!weight_warp)
asm volatile(
"cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::complete_"
"tx::bytes [%0], [%1, {%3,%4}], "
"[%2];"
:
: "r"(smem_ptr_act), "l"(desc_ptr_act), "r"(bar_ptr_act),
"r"(crd0), "r"(crd1)
: "memory");
}
stage += 4;
if (stage >= STAGES) {
stage = warp_id % 4;
phase ^= 1;
}
}
// Wait for pending loads to be consumed before exiting, to avoid race
for (int i = 0; i < (STAGES / 4) - 1; i++) {
bar_wait(__cvta_generic_to_shared(&bar_data_consumed[stage]), phase ^ 1);
stage += 4;
if (stage >= STAGES) {
stage = warp_id % 4;
phase ^= 1;
}
}
}
// Compute threads
else if (warp_id < 4) {
// Sneak the bias load into the compute warps since they're just waiting for
// stuff anyway
if (threadIdx.x < TILE_M) sh_bias[threadIdx.x] = bias[mib + threadIdx.x];
int stage = warp_id;
int phase = 0;
int lane_id_div8 = lane_id / 8;
int lane_id_mod8 = lane_id % 8;
int lane_row_offset_wt = (lane_id_div8 % 2) ? 8 : 0;
int lane_col_offset_wt = (lane_id_div8 / 2) ? 1 : 0;
int row_wt = lane_id_mod8 + lane_row_offset_wt;
int row_act = lane_id_mod8;
int row_offset_wt = (reinterpret_cast<uintptr_t>(sh_weights) / 128) % 8;
int row_offset_act = row_offset_wt;
uint32_t bar_ptr_wt = __cvta_generic_to_shared(&bar_wt_ready[stage]);
uint32_t bar_ptr_act = __cvta_generic_to_shared(&bar_act_ready[stage]);
bool weight_ready = bar_try_wait(bar_ptr_wt, phase);
bool act_ready = bar_try_wait(bar_ptr_act, phase);
#pragma unroll 2
for (int ki = 0; ki < K_LOOPS_COMPUTE; ki++) {
int next_stage = stage + 4;
int next_phase = phase;
if (next_stage >= STAGES) {
next_stage = warp_id;
next_phase ^= 1;
}
while (!weight_ready || !act_ready) {
weight_ready = bar_try_wait(bar_ptr_wt, phase);
act_ready = bar_try_wait(bar_ptr_act, phase);
}
if (PROFILE && blockIdx.y == 0 && threadIdx.x == 0 && ki == 0)
profile[blockIdx.x].compute_start = gclock64();
if (ki + 1 < K_LOOPS_COMPUTE) {
weight_ready = bar_try_wait(
__cvta_generic_to_shared(&bar_wt_ready[next_stage]), next_phase);
act_ready = bar_try_wait(
__cvta_generic_to_shared(&bar_act_ready[next_stage]), next_phase);
}
#pragma unroll
for (int su = 0; su < STAGE_UNROLL; su++) {
__nv_bfloat16* ptr_weights =
&sh_weights[(stage * STAGE_UNROLL + su) * TILE_M * TILE_K];
__nv_bfloat16* ptr_act =
&sh_activations[(stage * STAGE_UNROLL + su) * TILE_N * TILE_K];
#pragma unroll
for (int kii = 0; kii < TILE_K / 16; kii++) {
__nv_bfloat16 a[8];
__nv_bfloat16 b[4];
int col = 2 * kii + lane_col_offset_wt;
int col_sw = ((row_wt + row_offset_wt) % 8) ^ col;
ldmatrix4(a, __cvta_generic_to_shared(
&ptr_weights[row_wt * TILE_K + col_sw * 8]));
col = 2 * kii + lane_id_div8;
col_sw = ((row_act + row_offset_act) % 8) ^ col;
ldmatrix2(b, __cvta_generic_to_shared(
&ptr_act[row_act * TILE_K + 8 * col_sw]));
HMMA_16816(accum, a, b, accum);
}
}
uint32_t bar_c = __cvta_generic_to_shared(&bar_data_consumed[stage]);
asm volatile("mbarrier.arrive.shared::cta.b64 _, [%0];" : : "r"(bar_c));
stage = next_stage;
phase = next_phase;
}
float4 accum4;
accum4.x = accum[0];
accum4.y = accum[1];
accum4.z = accum[2];
accum4.w = accum[3];
reduction_buffer[threadIdx.x] = accum4;
__syncthreads();
if (warp_id == 0) {
int mi = mib + warp_id * WARP_TILE_M;
int tm = mi + lane_id / 4;
int tn = ni + 2 * (lane_id % 4);
float4 accum1 = reduction_buffer[32 + threadIdx.x];
float4 accum2 = reduction_buffer[64 + threadIdx.x];
float4 accum3 = reduction_buffer[96 + threadIdx.x];
accum[0] = accum[0] + accum1.x + accum2.x + accum3.x;
accum[1] = accum[1] + accum1.y + accum2.y + accum3.y;
accum[2] = accum[2] + accum1.z + accum2.z + accum3.z;
accum[3] = accum[3] + accum1.w + accum2.w + accum3.w;
float bias_lo = __bfloat162float(sh_bias[tm - mib]);
float bias_hi = __bfloat162float(sh_bias[tm + 8 - mib]);
if (tn < N && tm < M)
output[tn * M + tm] = __float2bfloat16(accum[0] + bias_lo);
if (tn + 1 < N && tm < M)
output[(tn + 1) * M + tm] = __float2bfloat16(accum[1] + bias_lo);
if (tn < N && tm + 8 < M)
output[tn * M + tm + 8] = __float2bfloat16(accum[2] + bias_hi);
if (tn + 1 < N && tm + 8 < M)
output[(tn + 1) * M + tm + 8] = __float2bfloat16(accum[3] + bias_hi);
if (PROFILE && blockIdx.y == 0 && threadIdx.x == 0)
profile[blockIdx.x].complete = gclock64();
}
}
#endif // end if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
}

View File

@@ -108,15 +108,6 @@ QUANT_CONFIGS = [
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [2],
},
# MXFP8
{
"a_type": ["kBFloat16"],
"b_type": "kFE4M3fn",
"s_type": "kFE8M0fnu",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [2],
},
# AWQ-INT4 with INT8 activation
{
"a_type": ["kS8"],

View File

@@ -343,8 +343,6 @@ __global__ void Marlin(
if constexpr (b_type == vllm::kFE2M1f) {
static_assert(s_type == vllm::kFE4M3fn && group_blocks == 1 ||
s_type == vllm::kFE8M0fnu && group_blocks == 2);
} else if constexpr (b_type == vllm::kFE4M3fn && s_type == vllm::kFE8M0fnu) {
static_assert(group_blocks == 2);
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
static_assert(s_type == vllm::kBFloat16);
} else if constexpr (std::is_same<scalar_t, half>::value) {
@@ -359,10 +357,9 @@ __global__ void Marlin(
constexpr bool is_int_type = b_type == vllm::kU4 || b_type == vllm::kU8 ||
b_type == vllm::kS4 || b_type == vllm::kS8 ||
b_type == vllm::kU4B8 || b_type == vllm::kU8B128;
constexpr bool is_8bit_scale = s_type.size_bits() == 8;
// see comments of dequant.h for more details
constexpr bool dequant_skip_flop =
is_a_8bit || (b_type == vllm::kFE4M3fn && !(s_type == vllm::kFE8M0fnu)) ||
is_a_8bit || b_type == vllm::kFE4M3fn ||
b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn ||
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
has_zp && !is_zp_float && !(b_type == vllm::kU8);
@@ -376,7 +373,7 @@ __global__ void Marlin(
const int group_size =
(!has_act_order && group_blocks == -1) ? prob_k : prob_k / num_groups;
const int scales_expert_stride =
prob_n * prob_k / group_size / (is_8bit_scale ? 16 : 8);
prob_n * prob_k / group_size / (b_type == vllm::kFE2M1f ? 16 : 8);
const int zp_expert_stride =
is_zp_float ? prob_n * prob_k / group_size / 8
: prob_n * prob_k / group_size / (pack_factor * 4);
@@ -695,8 +692,9 @@ __global__ void Marlin(
constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
// Scale sizes/strides without act_order
int s_gl_stride = prob_n / (is_8bit_scale ? 16 : 8);
constexpr int s_sh_stride = 16 * thread_n_blocks / (is_8bit_scale ? 16 : 8);
int s_gl_stride = prob_n / (b_type == vllm::kFE2M1f ? 16 : 8);
constexpr int s_sh_stride =
16 * thread_n_blocks / (b_type == vllm::kFE2M1f ? 16 : 8);
constexpr int s_tb_groups =
!has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
? thread_k_blocks / group_blocks
@@ -1133,7 +1131,7 @@ __global__ void Marlin(
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
if constexpr (!is_8bit_scale) {
if constexpr (b_type_id != vllm::kFE2M1f.id()) {
reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
} else {
@@ -1142,7 +1140,7 @@ __global__ void Marlin(
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
}
} else if (group_blocks >= b_sh_wr_iters) {
if constexpr (!is_8bit_scale) {
if constexpr (b_type_id != vllm::kFE2M1f.id()) {
reinterpret_cast<int4*>(&frag_s[1])[0] =
reinterpret_cast<int4*>(&frag_s[0])[0];
} else {
@@ -1343,7 +1341,7 @@ __global__ void Marlin(
}
}
if constexpr (s_type == vllm::kFE4M3fn || s_type == vllm::kFE8M0fnu) {
if constexpr (b_type == vllm::kFE2M1f) {
int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];
int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];

View File

@@ -599,9 +599,6 @@ torch::Tensor moe_wna16_marlin_gemm(
"When b_type = float4_e2m1f, b_scale scalar type must be",
"float8_e4m3fn (for NVFP4) or float8_e8m0fnu (for MXFP4).");
}
} else if (b_type_id == vllm::kFE4M3fn.id() &&
b_scales.scalar_type() == at::ScalarType::Float8_e8m0fnu) {
s_type_id = vllm::kFE8M0fnu.id();
}
vllm::ScalarType a_type = vllm::ScalarType::from_id(a_type_id);

View File

@@ -70,4 +70,8 @@ torch::Tensor router_gemm_bf16_fp32(torch::Tensor const& input,
// Supports num_tokens in [1, 16], num_experts in {256, 384}, hidden_dim = 7168
void dsv3_router_gemm(torch::Tensor& output, const torch::Tensor& mat_a,
const torch::Tensor& mat_b);
// gpt-oss optimized router GEMM kernel for SM90+
void gpt_oss_router_gemm(torch::Tensor& output, torch::Tensor input,
torch::Tensor weight, torch::Tensor bias);
#endif

View File

@@ -132,6 +132,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
// DeepSeek V3 optimized router GEMM for SM90+
m.def("dsv3_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
// conditionally compiled so impl registration is in source file
// gpt-oss optimized router GEMM kernel for SM90+
m.def(
"gpt_oss_router_gemm(Tensor! output, Tensor input, Tensor weights, "
"Tensor bias) -> ()");
m.impl("gpt_oss_router_gemm", torch::kCUDA, &gpt_oss_router_gemm);
#endif
}

View File

@@ -53,12 +53,12 @@ void paged_attention_v2(
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
void merge_attn_states(
torch::Tensor& output, std::optional<torch::Tensor> output_lse,
const torch::Tensor& prefix_output, const torch::Tensor& prefix_lse,
const torch::Tensor& suffix_output, const torch::Tensor& suffix_lse,
const std::optional<int64_t> prefill_tokens_with_context,
const std::optional<torch::Tensor>& output_scale = std::nullopt);
void merge_attn_states(torch::Tensor& output,
std::optional<torch::Tensor> output_lse,
const torch::Tensor& prefix_output,
const torch::Tensor& prefix_lse,
const torch::Tensor& suffix_output,
const torch::Tensor& suffix_lse);
#ifndef USE_ROCM
void convert_vertical_slash_indexes(
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
@@ -143,14 +143,6 @@ void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input,
std::optional<torch::Tensor> residual,
int64_t group_size, bool is_scale_transposed);
#ifndef USE_ROCM
void silu_and_mul_per_block_quant(torch::Tensor& out,
torch::Tensor const& input,
torch::Tensor& scales, int64_t group_size,
std::optional<torch::Tensor> scale_ub,
bool is_scale_transposed);
#endif
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
std::optional<torch::Tensor> key, int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox);
@@ -160,6 +152,12 @@ void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& scale);
#ifndef USE_ROCM
void silu_and_mul_nvfp4_quant(torch::Tensor& out,
torch::Tensor& output_block_scale,
torch::Tensor& input,
torch::Tensor& input_global_scale);
#endif
void persistent_masked_m_silu_mul_quant(
const at::Tensor& input, // (E, T, 2*H)
const at::Tensor& counts, // (E)
@@ -227,6 +225,89 @@ torch::Tensor ggml_moe_a8_vec(torch::Tensor X, torch::Tensor W,
int64_t ggml_moe_get_block_size(int64_t type);
#ifndef USE_ROCM
bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability);
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability);
bool cutlass_group_gemm_supported(int64_t cuda_device_capability);
void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A,
torch::Tensor const& B, torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha);
void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias);
void cutlass_moe_mm(
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch);
void cutlass_fp4_group_mm(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets);
void get_cutlass_moe_mm_data(
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
const int64_t num_experts, const int64_t n, const int64_t k,
const std::optional<torch::Tensor>& blockscale_offsets,
const bool is_gated);
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
const torch::Tensor& expert_first_token_offset,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
const int64_t n, const int64_t k, const bool swap_ab);
void get_cutlass_batched_moe_mm_data(torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
const torch::Tensor& expert_num_tokens,
const int64_t num_local_experts,
const int64_t padded_m, const int64_t n,
const int64_t k);
void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
std::optional<torch::Tensor> const& azp,
std::optional<torch::Tensor> const& bias);
std::tuple<torch::Tensor, torch::Tensor> scaled_fp4_quant_func(
torch::Tensor const& input, torch::Tensor const& input_scale,
bool is_sf_swizzled_layout);
void scaled_fp4_quant_out(torch::Tensor const& input,
torch::Tensor const& input_scale,
bool is_sf_swizzled_layout, torch::Tensor& output,
torch::Tensor& output_scale);
void scaled_fp4_experts_quant(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts);
void silu_and_mul_scaled_fp4_experts_quant(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts);
#endif
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor const& scale,
std::optional<torch::Tensor> const& azp);
@@ -262,7 +343,7 @@ void selective_scan_fwd(
const std::optional<torch::Tensor>& query_start_loc,
const std::optional<torch::Tensor>& cache_indices,
const std::optional<torch::Tensor>& has_initial_state,
const torch::Tensor& ssm_states, int64_t null_block_id, int64_t block_size,
const torch::Tensor& ssm_states, int64_t pad_slot_id, int64_t block_size,
const std::optional<torch::Tensor>& block_idx_first_scheduled_token,
const std::optional<torch::Tensor>& block_idx_last_scheduled_token,
const std::optional<torch::Tensor>& initial_state_idx,

View File

@@ -2,9 +2,10 @@
#pragma once
#include <cuda.h>
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include <torch/all.h>
#include <c10/cuda/CUDAStream.h>
#include "core/scalar_type.hpp"
#include "cutlass/bfloat16.h"
#include "cutlass/float8.h"
@@ -40,7 +41,7 @@ __global__ void get_group_gemm_starts(
}
#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE) \
else if (out_tensors.scalar_type() == TENSOR_C_TYPE) { \
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
get_group_gemm_starts<cutlass::float_e4m3_t, int32_t, C_TYPE, float, \
cutlass::Array<cutlass::float_e4m3_t, 8>> \
<<<1, num_experts, 0, stream>>>( \
@@ -65,34 +66,23 @@ __global__ void get_group_gemm_starts(
namespace {
void run_get_group_gemm_starts(
torch::stable::Tensor const& expert_offsets, torch::stable::Tensor& a_ptrs,
torch::stable::Tensor& b_ptrs, torch::stable::Tensor& out_ptrs,
torch::stable::Tensor& a_scales_ptrs, torch::stable::Tensor& b_scales_ptrs,
torch::stable::Tensor& b_group_scales_ptrs,
torch::stable::Tensor const& a_tensors,
torch::stable::Tensor const& b_tensors, torch::stable::Tensor& out_tensors,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& b_group_scales, const int64_t b_group_size) {
STD_TORCH_CHECK(a_tensors.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
STD_TORCH_CHECK(
b_tensors.scalar_type() ==
torch::headeronly::ScalarType::Int); // int4 8x packed into int32
STD_TORCH_CHECK(a_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(
b_group_scales.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn); // the underlying torch
// type is e4m3
STD_TORCH_CHECK(
out_tensors.scalar_type() ==
torch::headeronly::ScalarType::BFloat16); // only support bf16 for now
torch::Tensor const& expert_offsets, torch::Tensor& a_ptrs,
torch::Tensor& b_ptrs, torch::Tensor& out_ptrs,
torch::Tensor& a_scales_ptrs, torch::Tensor& b_scales_ptrs,
torch::Tensor& b_group_scales_ptrs, torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor& out_tensors,
torch::Tensor const& a_scales, torch::Tensor const& b_scales,
torch::Tensor const& b_group_scales, const int64_t b_group_size) {
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b_tensors.dtype() == torch::kInt32); // int4 8x packed into int32
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_group_scales.dtype() ==
torch::kFloat8_e4m3fn); // the underlying torch type is e4m3
TORCH_CHECK(out_tensors.dtype() ==
torch::kBFloat16); // only support bf16 for now
// expect int64_t to avoid overflow during offset calculations
STD_TORCH_CHECK(expert_offsets.scalar_type() ==
torch::headeronly::ScalarType::Long);
TORCH_CHECK(expert_offsets.dtype() == torch::kInt64);
int num_experts = static_cast<int>(expert_offsets.size(0));
// logical k, n
@@ -100,16 +90,15 @@ void run_get_group_gemm_starts(
int64_t k = a_tensors.size(1);
int64_t scale_k = cutlass::ceil_div(k, b_group_size);
auto stream = get_current_cuda_stream(a_tensors.get_device_index());
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
if (false) {
}
__CALL_GET_STARTS_KERNEL(torch::headeronly::ScalarType::BFloat16,
cutlass::bfloat16_t)
__CALL_GET_STARTS_KERNEL(torch::headeronly::ScalarType::Half, half)
__CALL_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t)
__CALL_GET_STARTS_KERNEL(torch::kFloat16, half)
else {
STD_TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
}
}
} // namespace
} // namespace

View File

@@ -14,12 +14,13 @@
#include "cutlass/util/mixed_dtype_utils.hpp"
// vllm includes
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include "cutlass_extensions/torch_utils.hpp"
#include "cutlass_extensions/common.hpp"
#include "core/registration.h"
#include "get_group_starts.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
#include "w4a8_utils.cuh"
@@ -167,40 +168,31 @@ struct W4A8GroupedGemmKernel {
static_assert(sizeof(LayoutB_Reordered) % sizeof(int32_t) == 0,
"LayoutB_Reordered size must be divisible by 4 bytes");
static void grouped_mm(torch::stable::Tensor& out_tensors,
const torch::stable::Tensor& a_tensors,
const torch::stable::Tensor& b_tensors,
const torch::stable::Tensor& a_scales,
const torch::stable::Tensor& b_scales,
const torch::stable::Tensor& b_group_scales,
const int64_t b_group_size,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& problem_sizes_torch,
const torch::stable::Tensor& a_strides,
const torch::stable::Tensor& b_strides,
const torch::stable::Tensor& c_strides,
const torch::stable::Tensor& group_scale_strides) {
static void grouped_mm(
torch::Tensor& out_tensors, const torch::Tensor& a_tensors,
const torch::Tensor& b_tensors, const torch::Tensor& a_scales,
const torch::Tensor& b_scales, const torch::Tensor& b_group_scales,
const int64_t b_group_size, const torch::Tensor& expert_offsets,
const torch::Tensor& problem_sizes_torch, const torch::Tensor& a_strides,
const torch::Tensor& b_strides, const torch::Tensor& c_strides,
const torch::Tensor& group_scale_strides) {
auto device = a_tensors.device();
auto device_id = device.index();
const torch::stable::accelerator::DeviceGuard device_guard(device_id);
auto stream = get_current_cuda_stream(device_id);
const at::cuda::OptionalCUDAGuard device_guard(device);
auto stream = at::cuda::getCurrentCUDAStream(device_id);
int num_experts = static_cast<int>(expert_offsets.size(0));
int n = static_cast<int>(b_tensors.size(1));
int k = static_cast<int>(b_tensors.size(2)) * PackFactor;
torch::stable::Tensor a_ptrs = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor b_ptrs = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor out_ptrs = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor a_scales_ptrs = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor b_scales_ptrs = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor b_group_scales_ptrs = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
auto options_int =
torch::TensorOptions().dtype(torch::kInt64).device(device);
torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_group_scales_ptrs = torch::empty(num_experts, options_int);
// get the correct offsets to pass to gemm
run_get_group_gemm_starts(expert_offsets, a_ptrs, b_ptrs, out_ptrs,
@@ -255,9 +247,9 @@ struct W4A8GroupedGemmKernel {
// Allocate workspace
size_t workspace_size = GemmShuffled::get_workspace_size(arguments);
torch::stable::Tensor workspace = torch::stable::empty(
workspace_size, torch::headeronly::ScalarType::Byte, std::nullopt,
device);
torch::Tensor workspace =
torch::empty(workspace_size,
torch::TensorOptions().dtype(torch::kU8).device(device));
// Run GEMM
GemmShuffled gemm;
@@ -302,20 +294,14 @@ using Kernel_256x128_2x1x1_Coop =
using Kernel_128x256_2x1x1_Coop =
W4A8GroupedGemmKernel<Shape<_128, _256>, Shape<_2, _1, _1>, Coop, CoopEpi>;
void mm_dispatch(torch::stable::Tensor& out_tensors,
const torch::stable::Tensor& a_tensors,
const torch::stable::Tensor& b_tensors,
const torch::stable::Tensor& a_scales,
const torch::stable::Tensor& b_scales,
const torch::stable::Tensor& b_group_scales,
const int64_t b_group_size,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& a_strides,
const torch::stable::Tensor& b_strides,
const torch::stable::Tensor& c_strides,
const torch::stable::Tensor& group_scale_strides,
const std::string& schedule) {
void mm_dispatch(
torch::Tensor& out_tensors, const torch::Tensor& a_tensors,
const torch::Tensor& b_tensors, const torch::Tensor& a_scales,
const torch::Tensor& b_scales, const torch::Tensor& b_group_scales,
const int64_t b_group_size, const torch::Tensor& expert_offsets,
const torch::Tensor& problem_sizes, const torch::Tensor& a_strides,
const torch::Tensor& b_strides, const torch::Tensor& c_strides,
const torch::Tensor& group_scale_strides, const std::string& schedule) {
if (schedule == "Kernel_128x16_1x1x1_Coop") {
Kernel_128x16_1x1x1_Coop::grouped_mm(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales,
@@ -372,23 +358,18 @@ void mm_dispatch(torch::stable::Tensor& out_tensors,
b_group_size, expert_offsets, problem_sizes, a_strides, b_strides,
c_strides, group_scale_strides);
} else {
STD_TORCH_CHECK(false,
"cutlass_w4a8_moe_mm: unknown schedule string: ", schedule);
TORCH_CHECK(false,
"cutlass_w4a8_moe_mm: unknown schedule string: ", schedule);
}
}
void mm(torch::stable::Tensor& out_tensors,
const torch::stable::Tensor& a_tensors,
const torch::stable::Tensor& b_tensors,
const torch::stable::Tensor& a_scales,
const torch::stable::Tensor& b_scales,
const torch::stable::Tensor& b_group_scales, const int64_t b_group_size,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& a_strides,
const torch::stable::Tensor& b_strides,
const torch::stable::Tensor& c_strides,
const torch::stable::Tensor& group_scale_strides,
void mm(torch::Tensor& out_tensors, const torch::Tensor& a_tensors,
const torch::Tensor& b_tensors, const torch::Tensor& a_scales,
const torch::Tensor& b_scales, const torch::Tensor& b_group_scales,
const int64_t b_group_size, const torch::Tensor& expert_offsets,
const torch::Tensor& problem_sizes, const torch::Tensor& a_strides,
const torch::Tensor& b_strides, const torch::Tensor& c_strides,
const torch::Tensor& group_scale_strides,
std::optional<std::string> maybe_schedule) {
// user has specified a schedule
if (maybe_schedule) {
@@ -425,27 +406,26 @@ void mm(torch::stable::Tensor& out_tensors,
a_strides, b_strides, c_strides, group_scale_strides, schedule);
}
std::tuple<torch::stable::Tensor, torch::stable::Tensor>
encode_and_reorder_int4b(torch::stable::Tensor const& b_tensors) {
STD_TORCH_CHECK(b_tensors.scalar_type() ==
torch::headeronly::ScalarType::Int);
STD_TORCH_CHECK(b_tensors.dim() == 3); // (experts, n, k)
STD_TORCH_CHECK(b_tensors.is_contiguous());
STD_TORCH_CHECK(b_tensors.is_cuda());
std::tuple<torch::Tensor, torch::Tensor> encode_and_reorder_int4b(
torch::Tensor const& b_tensors) {
TORCH_CHECK(b_tensors.dtype() == torch::kInt32);
TORCH_CHECK(b_tensors.dim() == 3); // (experts, n, k)
TORCH_CHECK(b_tensors.is_contiguous());
TORCH_CHECK(b_tensors.is_cuda());
int n = static_cast<int>(b_tensors.size(1));
int k = static_cast<int>(b_tensors.size(2)) * PackFactor; // logical k
// CUTLASS reorder_tensor requires k % 256 == 0 and n % 16 == 0.
// These misalignments cause silent OOB unless run under Compute Sanitizer.
STD_TORCH_CHECK(k % 256 == 0, "logical k must be divisible by 256");
STD_TORCH_CHECK(n % 16 == 0, "n must be divisible by 16");
TORCH_CHECK(k % 256 == 0, "logical k must be divisible by 256");
TORCH_CHECK(n % 16 == 0, "n must be divisible by 16");
// we will store the layout to an int32 tensor;
// this is the number of elements we need per layout
constexpr size_t layout_width = sizeof(LayoutB_Reordered) / sizeof(int32_t);
torch::stable::Tensor b_tensors_packed = torch::stable::empty_like(b_tensors);
torch::Tensor b_tensors_packed = torch::empty_like(b_tensors);
int num_experts = static_cast<int>(b_tensors.size(0));
auto b_ptr = static_cast<QuantType const*>(b_tensors.const_data_ptr());
@@ -455,7 +435,7 @@ encode_and_reorder_int4b(torch::stable::Tensor const& b_tensors) {
size_t num_int4_elems = 1ull * num_experts * n * k;
bool ok = vllm::cutlass_w4a8_utils::unified_encode_int4b(b_ptr, b_packed_ptr,
num_int4_elems);
STD_TORCH_CHECK(ok, "unified_encode_int4b failed");
TORCH_CHECK(ok, "unified_encode_int4b failed");
// construct the layout once; assumes each expert has the same layout
using LayoutType = LayoutB_Reordered;
@@ -476,28 +456,28 @@ encode_and_reorder_int4b(torch::stable::Tensor const& b_tensors) {
}
// save the packed layout to torch tensor so we can re-use it
torch::stable::Tensor layout_cpu = torch::stable::empty(
{num_experts, layout_width}, torch::headeronly::ScalarType::Int,
std::nullopt, torch::stable::Device(torch::stable::DeviceType::CPU));
auto cpu_opts =
torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU);
torch::Tensor layout_cpu =
torch::empty({num_experts, layout_width}, cpu_opts);
int32_t* layout_data = layout_cpu.mutable_data_ptr<int32_t>();
int32_t* layout_data = layout_cpu.data_ptr<int32_t>();
for (int i = 0; i < num_experts; ++i) {
std::memcpy(layout_data + i * layout_width, // dst (int32*)
&layout_B_reordered, // src (LayoutType*)
sizeof(LayoutType)); // number of bytes
}
torch::stable::Tensor packed_layout =
torch::stable::to(layout_cpu, b_tensors.device(),
/*non_blocking=*/false);
torch::Tensor packed_layout =
layout_cpu.to(b_tensors.device(), /*non_blocking=*/false);
return {b_tensors_packed, packed_layout};
}
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) {
m.impl("cutlass_w4a8_moe_mm", TORCH_BOX(&mm));
m.impl("cutlass_encode_and_reorder_int4b_grouped",
TORCH_BOX(&encode_and_reorder_int4b));
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("cutlass_w4a8_moe_mm", &mm);
m.impl("cutlass_encode_and_reorder_int4b_grouped", &encode_and_reorder_int4b);
}
} // namespace vllm::cutlass_w4a8_moe
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -3,12 +3,14 @@
// https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu
//
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include "cutlass_extensions/torch_utils.hpp"
#include "w4a8_utils.cuh"
#include "core/registration.h"
#include "cutlass/cutlass.h"
#include <limits>
@@ -159,31 +161,31 @@ struct W4A8GemmKernel {
using StrideD = typename GemmKernelShuffled::StrideD;
using StrideS = typename CollectiveMainloopShuffled::StrideScale;
static torch::stable::Tensor mm(
torch::stable::Tensor const& A,
torch::stable::Tensor const& B, // already packed
torch::stable::Tensor const& group_scales, // already packed
int64_t group_size, torch::stable::Tensor const& channel_scales,
torch::stable::Tensor const& token_scales,
std::optional<torch::headeronly::ScalarType> const& maybe_out_type) {
static torch::Tensor mm(torch::Tensor const& A,
torch::Tensor const& B, // already packed
torch::Tensor const& group_scales, // already packed
int64_t group_size,
torch::Tensor const& channel_scales,
torch::Tensor const& token_scales,
std::optional<at::ScalarType> const& maybe_out_type) {
// TODO: param validation
int m = A.size(0);
int k = A.size(1);
int n = B.size(1);
// safely cast group_size to int
STD_TORCH_CHECK(
group_size > 0 && group_size <= std::numeric_limits<int>::max(),
"group_size out of supported range for int: ", group_size);
TORCH_CHECK(group_size > 0 && group_size <= std::numeric_limits<int>::max(),
"group_size out of supported range for int: ", group_size);
int const group_size_int = static_cast<int>(group_size);
// Allocate output
const torch::stable::accelerator::DeviceGuard device_guard(
A.get_device_index());
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
auto device = A.device();
auto stream = get_current_cuda_stream(device.index());
torch::stable::Tensor D = torch::stable::empty(
{m, n}, equivalent_scalar_type_v<ElementD>, std::nullopt, device);
auto stream = at::cuda::getCurrentCUDAStream(device.index());
torch::Tensor D =
torch::empty({m, n}, torch::TensorOptions()
.dtype(equivalent_scalar_type_v<ElementD>)
.device(device));
// prepare arg pointers
auto A_ptr = static_cast<MmaType const*>(A.const_data_ptr());
auto B_ptr = static_cast<QuantType const*>(B.const_data_ptr());
@@ -235,9 +237,9 @@ struct W4A8GemmKernel {
// Workspace
size_t workspace_size = GemmShuffled::get_workspace_size(arguments);
torch::stable::Tensor workspace = torch::stable::empty(
workspace_size, torch::headeronly::ScalarType::Byte, std::nullopt,
device);
torch::Tensor workspace =
torch::empty(workspace_size,
torch::TensorOptions().dtype(torch::kU8).device(device));
// Run GEMM
GemmShuffled gemm;
@@ -267,14 +269,14 @@ using Kernel_128x64_1x1x1 = W4A8GemmKernel<Shape<_128, _64>, Shape<_1, _1, _1>>;
using Kernel_128x32_1x1x1 = W4A8GemmKernel<Shape<_128, _32>, Shape<_1, _1, _1>>;
using Kernel_128x16_1x1x1 = W4A8GemmKernel<Shape<_128, _16>, Shape<_1, _1, _1>>;
torch::stable::Tensor mm_dispatch(
torch::stable::Tensor const& A,
torch::stable::Tensor const& B, // already packed
torch::stable::Tensor const& group_scales, // already packed
int64_t group_size, torch::stable::Tensor const& channel_scales,
torch::stable::Tensor const& token_scales,
std::optional<torch::headeronly::ScalarType> const& maybe_out_type,
const std::string& schedule) {
torch::Tensor mm_dispatch(torch::Tensor const& A,
torch::Tensor const& B, // already packed
torch::Tensor const& group_scales, // already packed
int64_t group_size,
torch::Tensor const& channel_scales,
torch::Tensor const& token_scales,
std::optional<at::ScalarType> const& maybe_out_type,
const std::string& schedule) {
if (schedule == "256x128_1x1x1") {
return Kernel_256x128_1x1x1::mm(A, B, group_scales, group_size,
channel_scales, token_scales,
@@ -316,18 +318,17 @@ torch::stable::Tensor mm_dispatch(
channel_scales, token_scales,
maybe_out_type);
}
STD_TORCH_CHECK(false, "Unknown W4A8 schedule: ", schedule);
TORCH_CHECK(false, "Unknown W4A8 schedule: ", schedule);
return {};
}
torch::stable::Tensor mm(
torch::stable::Tensor const& A,
torch::stable::Tensor const& B, // already packed
torch::stable::Tensor const& group_scales, // already packed
int64_t group_size, torch::stable::Tensor const& channel_scales,
torch::stable::Tensor const& token_scales,
std::optional<torch::headeronly::ScalarType> const& maybe_out_type,
std::optional<std::string> maybe_schedule) {
torch::Tensor mm(torch::Tensor const& A,
torch::Tensor const& B, // already packed
torch::Tensor const& group_scales, // already packed
int64_t group_size, torch::Tensor const& channel_scales,
torch::Tensor const& token_scales,
std::optional<at::ScalarType> const& maybe_out_type,
std::optional<std::string> maybe_schedule) {
// requested a specific schedule
if (maybe_schedule) {
return mm_dispatch(A, B, group_scales, group_size, channel_scales,
@@ -377,15 +378,14 @@ torch::stable::Tensor mm(
// ----------------------------------------------------------------------------
// Pre-processing utils
// ----------------------------------------------------------------------------
torch::stable::Tensor pack_scale_fp8(torch::stable::Tensor const& scales) {
STD_TORCH_CHECK(scales.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
STD_TORCH_CHECK(scales.is_contiguous());
STD_TORCH_CHECK(scales.is_cuda());
torch::Tensor pack_scale_fp8(torch::Tensor const& scales) {
TORCH_CHECK(scales.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(scales.is_contiguous());
TORCH_CHECK(scales.is_cuda());
auto packed_scales =
torch::stable::empty({scales.numel() * ScalePackSize},
scales.scalar_type(), std::nullopt, scales.device());
auto packed_scales = torch::empty(
{scales.numel() * ScalePackSize},
torch::TensorOptions().dtype(scales.dtype()).device(scales.device()));
auto scales_ptr = static_cast<MmaType const*>(scales.const_data_ptr());
auto packed_scales_ptr =
static_cast<cutlass::Array<ElementScale, ScalePackSize>*>(
@@ -396,16 +396,15 @@ torch::stable::Tensor pack_scale_fp8(torch::stable::Tensor const& scales) {
return packed_scales;
}
torch::stable::Tensor encode_and_reorder_int4b(torch::stable::Tensor const& B) {
STD_TORCH_CHECK(B.scalar_type() == torch::headeronly::ScalarType::Int);
STD_TORCH_CHECK(B.dim() == 2);
torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) {
TORCH_CHECK(B.dtype() == torch::kInt32);
TORCH_CHECK(B.dim() == 2);
torch::stable::Tensor B_packed = torch::stable::empty_like(B);
torch::Tensor B_packed = torch::empty_like(B);
int k = B.size(0) * PackFactor; // logical k
int n = B.size(1);
STD_TORCH_CHECK((n * k) % 32 == 0,
"need multiples of 32 int4s for 16B chunks");
TORCH_CHECK((n * k) % 32 == 0, "need multiples of 32 int4s for 16B chunks");
auto B_ptr = static_cast<QuantType const*>(B.const_data_ptr());
auto B_packed_ptr = static_cast<QuantType*>(B_packed.data_ptr());
@@ -416,17 +415,16 @@ torch::stable::Tensor encode_and_reorder_int4b(torch::stable::Tensor const& B) {
bool ok = vllm::cutlass_w4a8_utils::unified_encode_int4b(B_ptr, B_packed_ptr,
n * k);
STD_TORCH_CHECK(ok, "unified_encode_int4b failed");
TORCH_CHECK(ok, "unified_encode_int4b failed");
cutlass::reorder_tensor(B_packed_ptr, layout_B, layout_B_reordered);
return B_packed;
}
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) {
m.impl("cutlass_w4a8_mm", TORCH_BOX(&mm));
m.impl("cutlass_pack_scale_fp8", TORCH_BOX(&pack_scale_fp8));
m.impl("cutlass_encode_and_reorder_int4b",
TORCH_BOX(&encode_and_reorder_int4b));
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("cutlass_w4a8_mm", &mm);
m.impl("cutlass_pack_scale_fp8", &pack_scale_fp8);
m.impl("cutlass_encode_and_reorder_int4b", &encode_and_reorder_int4b);
}
} // namespace vllm::cutlass_w4a8
} // namespace vllm::cutlass_w4a8

View File

@@ -14,15 +14,16 @@
* limitations under the License.
*/
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "libtorch_stable/dispatch_utils.h"
#include "cuda_vec_utils.cuh"
#include <torch/all.h>
#include <cuda_runtime_api.h>
#include <cuda_runtime.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_fp8.h>
#include "dispatch_utils.h"
#include "cuda_utils.h"
#include "launch_bounds_utils.h"
@@ -117,19 +118,17 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
} // namespace vllm
void silu_and_mul_nvfp4_quant_sm1xxa(
torch::stable::Tensor& output, // [..., d]
torch::stable::Tensor& output_sf,
torch::stable::Tensor& input, // [..., 2 * d]
torch::stable::Tensor& input_sf) {
void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d]
torch::Tensor& output_sf,
torch::Tensor& input, // [..., 2 * d]
torch::Tensor& input_sf) {
int32_t m = input.size(0);
int32_t n = input.size(1) / 2;
STD_TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");
STD_TORCH_CHECK(
input.scalar_type() == torch::headeronly::ScalarType::Half ||
input.scalar_type() == torch::headeronly::ScalarType::BFloat16,
"Unsupported input data type for quantize_to_fp4.");
TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");
TORCH_CHECK(input.scalar_type() == at::ScalarType::Half ||
input.scalar_type() == at::ScalarType::BFloat16,
"Unsupported input data type for quantize_to_fp4.");
int multiProcessorCount =
get_device_attribute(cudaDevAttrMultiProcessorCount, -1);
@@ -137,9 +136,8 @@ void silu_and_mul_nvfp4_quant_sm1xxa(
auto input_sf_ptr = static_cast<float const*>(input_sf.data_ptr());
auto sf_out = static_cast<int32_t*>(output_sf.data_ptr());
auto output_ptr = static_cast<int64_t*>(output.data_ptr());
const torch::stable::accelerator::DeviceGuard device_guard(
input.get_device_index());
auto stream = get_current_cuda_stream(input.get_device_index());
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
dim3 block(std::min(int(n / ELTS_PER_THREAD), 512));
int const numBlocksPerSM =
vllm_runtime_blocks_per_sm(static_cast<int>(block.x));
@@ -151,7 +149,7 @@ void silu_and_mul_nvfp4_quant_sm1xxa(
int(m), std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y));
dim3 grid(grid_x, grid_y);
VLLM_STABLE_DISPATCH_HALF_TYPES(
VLLM_DISPATCH_HALF_TYPES(
input.scalar_type(), "silu_and_mul_nvfp4_quant_kernel", [&] {
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());

View File

@@ -14,12 +14,14 @@
* limitations under the License.
*/
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "core/registration.h"
#include <torch/all.h>
#include <cutlass/arch/arch.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include "cutlass_extensions/common.hpp"
#include "cute/tensor.hpp"
@@ -120,7 +122,7 @@ __global__ void __get_group_gemm_starts(
#define __CALL_GET_STARTS_KERNEL_BLOCKSCALE(ELEMENT_AB_TYPE, SF_TYPE, \
TENSOR_C_TYPE, C_TYPE, LayoutSFA, \
LayoutSFB, ScaleConfig) \
else if (out_tensors.scalar_type() == TENSOR_C_TYPE) { \
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
__get_group_gemm_starts<ELEMENT_AB_TYPE, C_TYPE, SF_TYPE, float, \
LayoutSFA, LayoutSFB, ScaleConfig> \
<<<1, num_experts, 0, stream>>>( \
@@ -148,64 +150,50 @@ __global__ void __get_group_gemm_starts(
}
template <typename LayoutSFA, typename LayoutSFB, typename ScaleConfig>
void run_get_group_gemm_starts(const torch::stable::Tensor& a_starts,
const torch::stable::Tensor& b_starts,
const torch::stable::Tensor& out_starts,
const torch::stable::Tensor& a_scales_starts,
const torch::stable::Tensor& b_scales_starts,
const torch::stable::Tensor& alpha_starts,
const torch::stable::Tensor& layout_sfa,
const torch::stable::Tensor& layout_sfb,
const torch::stable::Tensor& a_strides,
const torch::stable::Tensor& b_strides,
const torch::stable::Tensor& c_strides,
int64_t a_stride_val, int64_t b_stride_val,
int64_t c_stride_val,
/*these are used for their base addresses*/
torch::stable::Tensor const& a_tensors,
torch::stable::Tensor const& b_tensors,
torch::stable::Tensor const& out_tensors,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& alphas,
torch::stable::Tensor const& expert_offsets,
torch::stable::Tensor const& sf_offsets,
torch::stable::Tensor const& problem_sizes,
int M, int N, int K) {
void run_get_group_gemm_starts(
const torch::Tensor& a_starts, const torch::Tensor& b_starts,
const torch::Tensor& out_starts, const torch::Tensor& a_scales_starts,
const torch::Tensor& b_scales_starts, const torch::Tensor& alpha_starts,
const torch::Tensor& layout_sfa, const torch::Tensor& layout_sfb,
const torch::Tensor& a_strides, const torch::Tensor& b_strides,
const torch::Tensor& c_strides, int64_t a_stride_val, int64_t b_stride_val,
int64_t c_stride_val,
/*these are used for their base addresses*/
torch::Tensor const& a_tensors, torch::Tensor const& b_tensors,
torch::Tensor const& out_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& alphas,
torch::Tensor const& expert_offsets, torch::Tensor const& sf_offsets,
torch::Tensor const& problem_sizes, int M, int N, int K) {
int num_experts = (int)expert_offsets.size(0);
auto stream = get_current_cuda_stream(a_tensors.get_device_index());
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
STD_TORCH_CHECK(out_tensors.size(1) == N,
"Output tensor shape doesn't match expected shape");
STD_TORCH_CHECK(K / 2 == b_tensors.size(2),
"b_tensors(dim = 2) and a_tensors(dim = 1) trailing"
" dimension must match");
TORCH_CHECK(out_tensors.size(1) == N,
"Output tensor shape doesn't match expected shape");
TORCH_CHECK(K / 2 == b_tensors.size(2),
"b_tensors(dim = 2) and a_tensors(dim = 1) trailing"
" dimension must match");
if (false) {
}
//(ELEMENT_AB_TYPE, BS_TYPE, TENSOR_C_TYPE, C_TYPE, LayoutSFA, LayoutSFB,
// ScaleConfig)
__CALL_GET_STARTS_KERNEL_BLOCKSCALE(
cutlass::float_e2m1_t, cutlass::float_ue4m3_t,
torch::headeronly::ScalarType::BFloat16, cutlass::bfloat16_t, LayoutSFA,
LayoutSFB, ScaleConfig)
cutlass::float_e2m1_t, cutlass::float_ue4m3_t, torch::kBFloat16,
cutlass::bfloat16_t, LayoutSFA, LayoutSFB, ScaleConfig)
__CALL_GET_STARTS_KERNEL_BLOCKSCALE(cutlass::float_e2m1_t,
cutlass::float_ue4m3_t,
torch::headeronly::ScalarType::Half, half,
LayoutSFA, LayoutSFB, ScaleConfig)
cutlass::float_ue4m3_t, torch::kFloat16,
half, LayoutSFA, LayoutSFB, ScaleConfig)
else {
STD_TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
}
}
template <typename OutType>
void run_fp4_blockwise_scaled_group_mm_sm100(
torch::stable::Tensor& output, const torch::stable::Tensor& a,
const torch::stable::Tensor& b, const torch::stable::Tensor& a_blockscale,
const torch::stable::Tensor& b_blockscales,
const torch::stable::Tensor& alphas,
const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& sf_offsets, int M, int N, int K) {
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M,
int N, int K) {
using ProblemShape =
cutlass::gemm::GroupProblemShape<Shape<int32_t, int32_t, int32_t>>;
using ElementType = cutlass::float_e2m1_t;
@@ -284,40 +272,20 @@ void run_fp4_blockwise_scaled_group_mm_sm100(
using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
int num_experts = static_cast<int>(expert_offsets.size(0));
auto options_int =
torch::TensorOptions().dtype(torch::kInt64).device(a.device());
torch::stable::Tensor a_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor b_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor out_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor a_scales_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor b_scales_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor alpha_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor layout_sfa = torch::stable::empty(
{num_experts, 5}, torch::headeronly::ScalarType::Long, std::nullopt,
a.device());
torch::stable::Tensor layout_sfb = torch::stable::empty(
{num_experts, 5}, torch::headeronly::ScalarType::Long, std::nullopt,
a.device());
torch::stable::Tensor a_strides1 =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor b_strides1 =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor c_strides1 =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int);
torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int);
torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int);
torch::Tensor a_strides1 = torch::empty(num_experts, options_int);
torch::Tensor b_strides1 = torch::empty(num_experts, options_int);
torch::Tensor c_strides1 = torch::empty(num_experts, options_int);
run_get_group_gemm_starts<LayoutSFA, LayoutSFB, ScaleConfig>(
a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, alpha_ptrs,
@@ -340,7 +308,7 @@ void run_fp4_blockwise_scaled_group_mm_sm100(
typename ProblemShape::UnderlyingProblemShape>::RasterOrderOptions;
typename Gemm::GemmKernel::TileSchedulerArguments scheduler;
scheduler.raster_order = RasterOrderOptions::AlongM;
hw_info.device_id = a.get_device_index();
hw_info.device_id = a.get_device();
static std::unordered_map<int, int> cached_sm_counts;
if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) {
cached_sm_counts[hw_info.device_id] =
@@ -382,35 +350,32 @@ void run_fp4_blockwise_scaled_group_mm_sm100(
scheduler};
size_t workspace_size = Gemm::get_workspace_size(args);
auto workspace =
torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
std::nullopt, a.device());
const cudaStream_t stream = get_current_cuda_stream(a.get_device_index());
auto const workspace_options =
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
auto workspace = torch::empty(workspace_size, workspace_options);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device());
auto can_implement_status = gemm_op.can_implement(args);
STD_TORCH_CHECK(
can_implement_status == cutlass::Status::kSuccess,
"Failed to implement GEMM: status=", (int)can_implement_status);
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess,
"Failed to implement GEMM: status=", (int)can_implement_status);
// Run the GEMM
auto status = gemm_op.initialize(args, workspace.data_ptr());
STD_TORCH_CHECK(status == cutlass::Status::kSuccess,
"Failed to initialize GEMM: status=", (int)status,
" workspace_size=", workspace_size,
" num_experts=", num_experts, " M=", M, " N=", N, " K=", K);
TORCH_CHECK(status == cutlass::Status::kSuccess,
"Failed to initialize GEMM: status=", (int)status,
" workspace_size=", workspace_size, " num_experts=", num_experts,
" M=", M, " N=", N, " K=", K);
status = gemm_op.run(args, workspace.data_ptr(), stream);
STD_TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
}
void run_fp4_blockwise_scaled_group_mm_sm120(
torch::stable::Tensor& output, const torch::stable::Tensor& a,
const torch::stable::Tensor& b, const torch::stable::Tensor& a_blockscale,
const torch::stable::Tensor& b_blockscales,
const torch::stable::Tensor& alphas,
const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& sf_offsets, int M, int N, int K) {
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M,
int N, int K) {
using ProblemShape =
cutlass::gemm::GroupProblemShape<Shape<int32_t, int32_t, int32_t>>;
using ElementType = cutlass::float_e2m1_t;
@@ -481,40 +446,20 @@ void run_fp4_blockwise_scaled_group_mm_sm120(
using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
int num_experts = static_cast<int>(expert_offsets.size(0));
auto options_int =
torch::TensorOptions().dtype(torch::kInt64).device(a.device());
torch::stable::Tensor a_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor b_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor out_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor a_scales_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor b_scales_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor alpha_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor layout_sfa = torch::stable::empty(
{num_experts, 5}, torch::headeronly::ScalarType::Long, std::nullopt,
a.device());
torch::stable::Tensor layout_sfb = torch::stable::empty(
{num_experts, 5}, torch::headeronly::ScalarType::Long, std::nullopt,
a.device());
torch::stable::Tensor a_strides1 =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor b_strides1 =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor c_strides1 =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int);
torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int);
torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int);
torch::Tensor a_strides1 = torch::empty(num_experts, options_int);
torch::Tensor b_strides1 = torch::empty(num_experts, options_int);
torch::Tensor c_strides1 = torch::empty(num_experts, options_int);
run_get_group_gemm_starts<LayoutSFA, LayoutSFB, ScaleConfig>(
a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, alpha_ptrs,
@@ -535,7 +480,7 @@ void run_fp4_blockwise_scaled_group_mm_sm120(
using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions;
typename Gemm::GemmKernel::TileSchedulerArguments scheduler;
scheduler.raster_order = RasterOrderOptions::AlongM;
hw_info.device_id = a.get_device_index();
hw_info.device_id = a.get_device();
static std::unordered_map<int, int> cached_sm_counts;
if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) {
cached_sm_counts[hw_info.device_id] =
@@ -578,36 +523,33 @@ void run_fp4_blockwise_scaled_group_mm_sm120(
scheduler};
size_t workspace_size = Gemm::get_workspace_size(args);
auto workspace =
torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
std::nullopt, a.device());
const cudaStream_t stream = get_current_cuda_stream(a.get_device_index());
auto const workspace_options =
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
auto workspace = torch::empty(workspace_size, workspace_options);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device());
auto can_implement_status = gemm_op.can_implement(args);
STD_TORCH_CHECK(
can_implement_status == cutlass::Status::kSuccess,
"Failed to implement GEMM: status=", (int)can_implement_status);
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess,
"Failed to implement GEMM: status=", (int)can_implement_status);
// Run the GEMM
auto status = gemm_op.initialize(args, workspace.data_ptr());
STD_TORCH_CHECK(status == cutlass::Status::kSuccess,
"Failed to initialize GEMM: status=", (int)status,
" workspace_size=", workspace_size,
" num_experts=", num_experts, " M=", M, " N=", N, " K=", K);
TORCH_CHECK(status == cutlass::Status::kSuccess,
"Failed to initialize GEMM: status=", (int)status,
" workspace_size=", workspace_size, " num_experts=", num_experts,
" M=", M, " N=", N, " K=", K);
status = gemm_op.run(args, workspace.data_ptr(), stream);
STD_TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
}
template <typename OutType>
void run_fp4_blockwise_scaled_group_mm(
torch::stable::Tensor& output, const torch::stable::Tensor& a,
const torch::stable::Tensor& b, const torch::stable::Tensor& a_blockscale,
const torch::stable::Tensor& b_blockscales,
const torch::stable::Tensor& alphas,
const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& sf_offsets, int M, int N, int K) {
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M,
int N, int K) {
int32_t version_num = get_sm_version_num();
#if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
if (version_num >= 120 && version_num < 130) {
@@ -625,7 +567,7 @@ void run_fp4_blockwise_scaled_group_mm(
return;
}
#endif
STD_TORCH_CHECK_NOT_IMPLEMENTED(
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_fp4_group_mm kernel for CUDA device capability: ",
version_num, ". Required capability: 100 or 120");
@@ -633,31 +575,26 @@ void run_fp4_blockwise_scaled_group_mm(
#if (defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100) || \
(defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120)
constexpr auto FLOAT4_E2M1X2 = torch::headeronly::ScalarType::Byte;
constexpr auto SF_DTYPE = torch::headeronly::ScalarType::Float8_e4m3fn;
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
#endif
#define CHECK_TYPE(x, st, m) \
STD_TORCH_CHECK(x.scalar_type() == st, \
": Inconsistency of torch::stable::Tensor type:", m)
#define CHECK_TYPE(x, st, m) \
TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m)
#define CHECK_TH_CUDA(x, m) \
STD_TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor.")
TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor.")
#define CHECK_CONTIGUOUS(x, m) \
STD_TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous.")
TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous.")
#define CHECK_INPUT(x, st, m) \
CHECK_TH_CUDA(x, m); \
CHECK_CONTIGUOUS(x, m); \
CHECK_TYPE(x, st, m)
void cutlass_fp4_group_mm(torch::stable::Tensor& output,
const torch::stable::Tensor& a,
const torch::stable::Tensor& b,
const torch::stable::Tensor& a_blockscale,
const torch::stable::Tensor& b_blockscales,
const torch::stable::Tensor& alphas,
const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& sf_offsets) {
void cutlass_fp4_group_mm(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets) {
#if (defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100) || \
(defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120)
// Input validation
@@ -665,34 +602,30 @@ void cutlass_fp4_group_mm(torch::stable::Tensor& output,
CHECK_INPUT(b, FLOAT4_E2M1X2, "b");
CHECK_INPUT(a_blockscale, SF_DTYPE, "a_blockscale");
CHECK_INPUT(b_blockscales, SF_DTYPE, "b_blockscales");
CHECK_INPUT(alphas, torch::headeronly::ScalarType::Float, "alphas");
CHECK_INPUT(alphas, at::ScalarType::Float, "alphas");
STD_TORCH_CHECK(
a_blockscale.dim() == 2,
"expected a_blockscale to be of shape [num_experts, rounded_m,"
" k // group_size], observed rank: ",
a_blockscale.dim())
STD_TORCH_CHECK(b_blockscales.dim() == 3,
"expected b_blockscale to be of shape: "
" [num_experts, n, k // group_size], observed rank: ",
b_blockscales.dim())
STD_TORCH_CHECK(problem_sizes.dim() == 2,
"problem_sizes must be a 2D tensor");
STD_TORCH_CHECK(problem_sizes.size(1) == 3,
"problem_sizes must have the shape (num_experts, 3)");
STD_TORCH_CHECK(
problem_sizes.size(0) == expert_offsets.size(0),
"Number of experts in problem_sizes must match expert_offsets");
STD_TORCH_CHECK(
problem_sizes.scalar_type() == torch::headeronly::ScalarType::Int,
"problem_sizes must be int32.");
TORCH_CHECK(a_blockscale.dim() == 2,
"expected a_blockscale to be of shape [num_experts, rounded_m,"
" k // group_size], observed rank: ",
a_blockscale.dim())
TORCH_CHECK(b_blockscales.dim() == 3,
"expected b_blockscale to be of shape: "
" [num_experts, n, k // group_size], observed rank: ",
b_blockscales.dim())
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be a 2D tensor");
TORCH_CHECK(problem_sizes.size(1) == 3,
"problem_sizes must have the shape (num_experts, 3)");
TORCH_CHECK(problem_sizes.size(0) == expert_offsets.size(0),
"Number of experts in problem_sizes must match expert_offsets");
TORCH_CHECK(problem_sizes.dtype() == torch::kInt32,
"problem_sizes must be int32.");
int M = static_cast<int>(a.size(0));
int N = static_cast<int>(b.size(1));
int E = static_cast<int>(b.size(0));
int K = static_cast<int>(2 * b.size(2));
if (output.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
if (output.scalar_type() == torch::kBFloat16) {
run_fp4_blockwise_scaled_group_mm<cutlass::bfloat16_t>(
output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes,
expert_offsets, sf_offsets, M, N, K);
@@ -700,7 +633,7 @@ void cutlass_fp4_group_mm(torch::stable::Tensor& output,
#if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
int32_t version_num = get_sm_version_num();
if (version_num >= 120 && version_num < 130) {
STD_TORCH_CHECK_NOT_IMPLEMENTED(
TORCH_CHECK_NOT_IMPLEMENTED(
false, "SM120 NVFP4 MOE only supports bfloat16 output, got: ",
output.scalar_type());
}
@@ -710,7 +643,7 @@ void cutlass_fp4_group_mm(torch::stable::Tensor& output,
expert_offsets, sf_offsets, M, N, K);
}
#else
STD_TORCH_CHECK_NOT_IMPLEMENTED(
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_fp4_group_mm kernel, vLLM must "
"be compiled with ENABLE_NVFP4_SM100 or ENABLE_NVFP4_SM120 for SM100/120 "
@@ -718,6 +651,6 @@ void cutlass_fp4_group_mm(torch::stable::Tensor& output,
#endif
}
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) {
m.impl("cutlass_fp4_group_mm", TORCH_BOX(&cutlass_fp4_group_mm));
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("cutlass_fp4_group_mm", &cutlass_fp4_group_mm);
}

View File

@@ -14,15 +14,16 @@
* limitations under the License.
*/
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "libtorch_stable/dispatch_utils.h"
#include "cuda_vec_utils.cuh"
#include <torch/all.h>
#include <cuda_runtime_api.h>
#include <cuda_runtime.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_fp8.h>
#include "dispatch_utils.h"
#include "cuda_utils.h"
#include "nvfp4_utils.cuh"
@@ -326,28 +327,25 @@ void quant_impl(void* output, void* output_scale, void* input,
} // namespace vllm
/*Quantization entry for fp4 experts quantization*/
#define CHECK_TH_CUDA(x, m) \
STD_TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x, m) \
STD_TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
#define CHECK_INPUT(x, m) \
CHECK_TH_CUDA(x, m); \
CHECK_CONTIGUOUS(x, m);
constexpr auto HALF = torch::headeronly::ScalarType::Half;
constexpr auto BF16 = torch::headeronly::ScalarType::BFloat16;
constexpr auto FLOAT = torch::headeronly::ScalarType::Float;
constexpr auto INT = torch::headeronly::ScalarType::Int;
constexpr auto UINT8 = torch::headeronly::ScalarType::Byte;
constexpr auto HALF = at::ScalarType::Half;
constexpr auto BF16 = at::ScalarType::BFloat16;
constexpr auto FLOAT = at::ScalarType::Float;
constexpr auto INT = at::ScalarType::Int;
constexpr auto UINT8 = at::ScalarType::Byte;
// Common validation for fp4 experts quantization entry points.
static void validate_fp4_experts_quant_inputs(
torch::stable::Tensor const& output,
torch::stable::Tensor const& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_global_scale,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts, int64_t m_topk,
torch::Tensor const& output, torch::Tensor const& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts, int64_t m_topk,
int64_t k) {
CHECK_INPUT(output, "output");
CHECK_INPUT(output_scale, "output_scale");
@@ -356,42 +354,41 @@ static void validate_fp4_experts_quant_inputs(
CHECK_INPUT(input_offset_by_experts, "input_offset_by_experts");
CHECK_INPUT(output_scale_offset_by_experts, "output_scale_offset_by_experts");
STD_TORCH_CHECK(output.dim() == 2);
STD_TORCH_CHECK(output_scale.dim() == 2);
STD_TORCH_CHECK(input.dim() == 2);
STD_TORCH_CHECK(input_global_scale.dim() == 1);
STD_TORCH_CHECK(input_offset_by_experts.dim() == 1);
STD_TORCH_CHECK(output_scale_offset_by_experts.dim() == 1);
TORCH_CHECK(output.dim() == 2);
TORCH_CHECK(output_scale.dim() == 2);
TORCH_CHECK(input.dim() == 2);
TORCH_CHECK(input_global_scale.dim() == 1);
TORCH_CHECK(input_offset_by_experts.dim() == 1);
TORCH_CHECK(output_scale_offset_by_experts.dim() == 1);
STD_TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16);
STD_TORCH_CHECK(input_global_scale.scalar_type() == FLOAT);
STD_TORCH_CHECK(input_offset_by_experts.scalar_type() == INT);
STD_TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT);
TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16);
TORCH_CHECK(input_global_scale.scalar_type() == FLOAT);
TORCH_CHECK(input_offset_by_experts.scalar_type() == INT);
TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT);
// output is uint8 (two nvfp4 values are packed into one uint8)
// output_scale is int32 (four fp8 values are packed into one int32)
STD_TORCH_CHECK(output.scalar_type() == UINT8);
STD_TORCH_CHECK(output_scale.scalar_type() == INT);
TORCH_CHECK(output.scalar_type() == UINT8);
TORCH_CHECK(output_scale.scalar_type() == INT);
const int BLOCK_SIZE = 16;
STD_TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16");
TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16");
auto n_experts = input_global_scale.size(0);
STD_TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1);
STD_TORCH_CHECK(output_scale_offset_by_experts.size(0) == n_experts + 1);
STD_TORCH_CHECK(output.size(0) == m_topk);
STD_TORCH_CHECK(output.size(1) == k / 2);
TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1);
TORCH_CHECK(output_scale_offset_by_experts.size(0) == n_experts + 1);
TORCH_CHECK(output.size(0) == m_topk);
TORCH_CHECK(output.size(1) == k / 2);
int scales_k = k / BLOCK_SIZE;
// 4 means the swizzle requirement by nvidia nvfp4.
int padded_k = (scales_k + (4 - 1)) / 4 * 4;
// 4 means 4 fp8 values are packed into one int32
STD_TORCH_CHECK(output_scale.size(1) * 4 == padded_k);
TORCH_CHECK(output_scale.size(1) * 4 == padded_k);
}
void scaled_fp4_experts_quant_sm1xxa(
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_global_scale,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts) {
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts) {
auto m_topk = input.size(0);
auto k = input.size(1);
@@ -400,11 +397,11 @@ void scaled_fp4_experts_quant_sm1xxa(
output_scale_offset_by_experts, m_topk, k);
auto n_experts = input_global_scale.size(0);
const torch::stable::accelerator::DeviceGuard device_guard(
input.get_device_index());
const cudaStream_t stream = get_current_cuda_stream(input.get_device_index());
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream =
at::cuda::getCurrentCUDAStream(input.get_device());
VLLM_STABLE_DISPATCH_HALF_TYPES(
VLLM_DISPATCH_HALF_TYPES(
input.scalar_type(), "nvfp4_experts_quant_kernel", [&] {
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
vllm::quant_impl<cuda_type, /*FUSE_SILU_MUL=*/false>(
@@ -416,15 +413,14 @@ void scaled_fp4_experts_quant_sm1xxa(
}
void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_global_scale,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts) {
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts) {
auto m_topk = input.size(0);
// Input has gate || up layout, so k = input.size(1) / 2
auto k_times_2 = input.size(1);
STD_TORCH_CHECK(k_times_2 % 2 == 0, "input width must be even (gate || up)");
TORCH_CHECK(k_times_2 % 2 == 0, "input width must be even (gate || up)");
auto k = k_times_2 / 2;
validate_fp4_experts_quant_inputs(output, output_scale, input,
@@ -432,11 +428,11 @@ void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
output_scale_offset_by_experts, m_topk, k);
auto n_experts = input_global_scale.size(0);
const torch::stable::accelerator::DeviceGuard device_guard(
input.get_device_index());
const cudaStream_t stream = get_current_cuda_stream(input.get_device_index());
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream =
at::cuda::getCurrentCUDAStream(input.get_device());
VLLM_STABLE_DISPATCH_HALF_TYPES(
VLLM_DISPATCH_HALF_TYPES(
input.scalar_type(), "silu_mul_nvfp4_experts_quant_kernel", [&] {
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
vllm::quant_impl<cuda_type, /*FUSE_SILU_MUL=*/true>(

View File

@@ -0,0 +1,134 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/all.h>
#include "nvfp4_utils.cuh"
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
torch::Tensor const& input,
torch::Tensor const& output_sf,
torch::Tensor const& input_sf,
bool is_sf_swizzled_layout);
#endif
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
void scaled_fp4_experts_quant_sm1xxa(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts);
#endif
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output,
torch::Tensor& output_sf,
torch::Tensor& input,
torch::Tensor& input_sf);
#endif
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts);
#endif
void scaled_fp4_quant_out(torch::Tensor const& input,
torch::Tensor const& input_sf,
bool is_sf_swizzled_layout, torch::Tensor& output,
torch::Tensor& output_sf) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
return scaled_fp4_quant_sm1xxa(output, input, output_sf, input_sf,
is_sf_swizzled_layout);
#endif
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization kernel");
}
std::tuple<torch::Tensor, torch::Tensor> scaled_fp4_quant_func(
torch::Tensor const& input, torch::Tensor const& input_sf,
bool is_sf_swizzled_layout) {
int64_t n = input.size(-1);
int64_t m = input.numel() / n;
auto device = input.device();
// Two fp4 values packed into a uint8
auto output = torch::empty(
{m, n / 2}, torch::TensorOptions().device(device).dtype(torch::kUInt8));
torch::Tensor output_sf;
if (is_sf_swizzled_layout) {
auto [sf_m, sf_n] = vllm::computeSwizzledSFShape(m, n);
output_sf = torch::empty(
{sf_m, sf_n},
torch::TensorOptions().device(device).dtype(torch::kInt32));
} else {
output_sf = torch::empty(
{m, n / CVT_FP4_SF_VEC_SIZE},
torch::TensorOptions().device(device).dtype(torch::kUInt8));
}
scaled_fp4_quant_out(input, input_sf, is_sf_swizzled_layout, output,
output_sf);
return {output, output_sf};
}
void scaled_fp4_experts_quant(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
return scaled_fp4_experts_quant_sm1xxa(
output, output_scale, input, input_global_scale, input_offset_by_experts,
output_scale_offset_by_experts);
#endif
TORCH_CHECK_NOT_IMPLEMENTED(false,
"No compiled nvfp4 experts quantization kernel");
}
void silu_and_mul_nvfp4_quant(torch::Tensor& output, torch::Tensor& output_sf,
torch::Tensor& input, torch::Tensor& input_sf) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
return silu_and_mul_nvfp4_quant_sm1xxa(output, output_sf, input, input_sf);
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
false, "No compiled silu_and_mul nvfp4 quantization kernel");
}
void silu_and_mul_scaled_fp4_experts_quant(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
return silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
output, output_scale, input, input_global_scale, input_offset_by_experts,
output_scale_offset_by_experts);
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
false, "No compiled silu_and_mul nvfp4 experts quantization kernel");
}

View File

@@ -14,16 +14,16 @@
* limitations under the License.
*/
#include <torch/csrc/stable/tensor.h>
#include <torch/all.h>
#include <cuda_runtime_api.h>
#include <cuda_runtime.h>
#include <cuda_fp8.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "libtorch_stable/torch_utils.h"
#include "libtorch_stable/dispatch_utils.h"
#include "cuda_vec_utils.cuh"
#include <cuda_fp8.h>
#include "dispatch_utils.h"
#include "cuda_utils.h"
#include "launch_bounds_utils.h"
@@ -173,19 +173,18 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
} // namespace vllm
void scaled_fp4_quant_sm1xxa(torch::stable::Tensor const& output,
torch::stable::Tensor const& input,
torch::stable::Tensor const& output_sf,
torch::stable::Tensor const& input_sf,
void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
torch::Tensor const& input,
torch::Tensor const& output_sf,
torch::Tensor const& input_sf,
bool is_sf_swizzled_layout) {
int32_t m = input.size(0);
int32_t n = input.size(1);
STD_TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");
STD_TORCH_CHECK(
input.scalar_type() == torch::headeronly::ScalarType::Half ||
input.scalar_type() == torch::headeronly::ScalarType::BFloat16,
"Unsupported input data type for quantize_to_fp4.");
TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");
TORCH_CHECK(input.scalar_type() == at::ScalarType::Half ||
input.scalar_type() == at::ScalarType::BFloat16,
"Unsupported input data type for quantize_to_fp4.");
int multiProcessorCount =
get_device_attribute(cudaDevAttrMultiProcessorCount, -1);
@@ -193,9 +192,8 @@ void scaled_fp4_quant_sm1xxa(torch::stable::Tensor const& output,
auto input_sf_ptr = static_cast<float const*>(input_sf.data_ptr());
auto sf_out = static_cast<int32_t*>(output_sf.data_ptr());
auto output_ptr = static_cast<int64_t*>(output.data_ptr());
const torch::stable::accelerator::DeviceGuard device_guard(
input.get_device_index());
auto stream = get_current_cuda_stream(input.get_device_index());
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
int sf_n_unpadded = int(n / CVT_FP4_SF_VEC_SIZE);
@@ -215,15 +213,15 @@ void scaled_fp4_quant_sm1xxa(torch::stable::Tensor const& output,
std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y));
dim3 grid(grid_x, grid_y);
VLLM_STABLE_DISPATCH_HALF_TYPES(
input.scalar_type(), "nvfp4_quant_kernel", [&] {
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
vllm::cvt_fp16_to_fp4<cuda_type, false><<<grid, block, 0, stream>>>(
m, n, num_padded_cols, input_ptr, input_sf_ptr,
reinterpret_cast<uint32_t*>(output_ptr),
reinterpret_cast<uint32_t*>(sf_out));
});
VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] {
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
// NOTE: We don't support e8m0 scales at this moment.
vllm::cvt_fp16_to_fp4<cuda_type, false><<<grid, block, 0, stream>>>(
m, n, num_padded_cols, input_ptr, input_sf_ptr,
reinterpret_cast<uint32_t*>(output_ptr),
reinterpret_cast<uint32_t*>(sf_out));
});
} else {
int num_packed_cols = n / CVT_FP4_ELTS_PER_THREAD;
int grid_y = vllm::div_round_up(num_packed_cols, static_cast<int>(block.x));
@@ -231,15 +229,15 @@ void scaled_fp4_quant_sm1xxa(torch::stable::Tensor const& output,
m, std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y));
dim3 grid(grid_x, grid_y);
VLLM_STABLE_DISPATCH_HALF_TYPES(
input.scalar_type(), "nvfp4_quant_kernel", [&] {
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
vllm::cvt_fp16_to_fp4_sf_major<cuda_type, false>
<<<grid, block, 0, stream>>>(
m, n, sf_n_unpadded, num_packed_cols, input_ptr, input_sf_ptr,
reinterpret_cast<uint32_t*>(output_ptr),
reinterpret_cast<uint32_t*>(sf_out));
});
VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] {
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
// NOTE: We don't support e8m0 scales at this moment.
vllm::cvt_fp16_to_fp4_sf_major<cuda_type, false>
<<<grid, block, 0, stream>>>(m, n, sf_n_unpadded, num_packed_cols,
input_ptr, input_sf_ptr,
reinterpret_cast<uint32_t*>(output_ptr),
reinterpret_cast<uint32_t*>(sf_out));
});
}
}

View File

@@ -0,0 +1,67 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include "cutlass_extensions/common.hpp"
#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
torch::Tensor const& B,
torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha);
#endif
#if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
void cutlass_scaled_fp4_mm_sm120a(torch::Tensor& D, torch::Tensor const& A,
torch::Tensor const& B,
torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha);
#endif
void cutlass_scaled_fp4_mm(torch::Tensor& D, const torch::Tensor& A,
const torch::Tensor& B, const torch::Tensor& A_sf,
const torch::Tensor& B_sf,
const torch::Tensor& alpha) {
// Make sure were on As device.
const c10::cuda::OptionalCUDAGuard device_guard(device_of(A));
const int32_t sm = get_sm_version_num();
#if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100
if (sm >= 100 && sm < 120) {
cutlass_scaled_fp4_mm_sm100a(D, A, B, A_sf, B_sf, alpha);
return;
}
#endif
#if defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120
if (sm >= 120 && sm < 130) {
cutlass_scaled_fp4_mm_sm120a(D, A, B, A_sf, B_sf, alpha);
return;
}
#endif
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 mm kernel for SM ", sm,
". Recompile with CUDA >= 12.8 and CC >= 100.");
}
bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability) {
int runtimeVersion;
cudaRuntimeGetVersion(&runtimeVersion);
return cuda_device_capability >= 100 && runtimeVersion >= 12080;
}

View File

@@ -14,9 +14,10 @@
* limitations under the License.
*/
#include <torch/csrc/stable/tensor.h>
#include <torch/all.h>
#include "libtorch_stable/torch_utils.h"
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "cutlass_extensions/common.hpp"
@@ -126,9 +127,8 @@ struct Fp4GemmSm100 {
template <typename Config>
typename Config::Gemm::Arguments args_from_options(
torch::stable::Tensor& D, torch::stable::Tensor const& A,
torch::stable::Tensor const& B, torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf, torch::stable::Tensor const& alpha,
at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
at::Tensor const& A_sf, at::Tensor const& B_sf, at::Tensor const& alpha,
int64_t M, int64_t N, int64_t K) {
using ElementA = typename Config::Gemm::ElementA;
using ElementB = typename Config::Gemm::ElementB;
@@ -174,20 +174,19 @@ typename Config::Gemm::Arguments args_from_options(
}
template <typename Config>
void runGemm(torch::stable::Tensor& D, torch::stable::Tensor const& A,
torch::stable::Tensor const& B, torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha, int64_t m, int64_t n,
int64_t k, cudaStream_t stream) {
void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
at::Tensor const& A_sf, at::Tensor const& B_sf,
at::Tensor const& alpha, int64_t m, int64_t n, int64_t k,
cudaStream_t stream) {
typename Config::Gemm gemm;
auto arguments =
args_from_options<Config>(D, A, B, A_sf, B_sf, alpha, m, n, k);
size_t workspace_size = Config::Gemm::get_workspace_size(arguments);
auto workspace =
torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
std::nullopt, A.device());
auto const workspace_options =
torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
auto workspace = torch::empty(workspace_size, workspace_options);
CUTLASS_CHECK(gemm.can_implement(arguments));
@@ -198,13 +197,12 @@ void runGemm(torch::stable::Tensor& D, torch::stable::Tensor const& A,
// Dispatch function to select appropriate config based on M
template <typename OutType>
void cutlass_fp4_gemm_dispatch(torch::stable::Tensor& D,
torch::stable::Tensor const& A,
torch::stable::Tensor const& B,
torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha, int64_t m,
int64_t n, int64_t k, cudaStream_t stream) {
void cutlass_fp4_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
torch::Tensor const& B,
torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha, int64_t m, int64_t n,
int64_t k, cudaStream_t stream) {
uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
if (mp2 <= 16) {
@@ -224,65 +222,61 @@ void cutlass_fp4_gemm_dispatch(torch::stable::Tensor& D,
#else
template <typename OutType>
void cutlass_fp4_gemm_dispatch(torch::stable::Tensor& D,
torch::stable::Tensor const& A,
torch::stable::Tensor const& B,
torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha, int64_t m,
int64_t n, int64_t k, cudaStream_t stream) {
STD_TORCH_CHECK(false,
"Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
"a CUTLASS 3.8 source directory to enable support.");
void cutlass_fp4_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
torch::Tensor const& B,
torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha, int64_t m, int64_t n,
int64_t k, cudaStream_t stream) {
TORCH_CHECK(false,
"Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
"a CUTLASS 3.8 source directory to enable support.");
}
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
#define CHECK_TYPE(x, st, m) \
STD_TORCH_CHECK(x.scalar_type() == st, \
": Inconsistency of torch::stable::Tensor type:", m)
#define CHECK_TYPE(x, st, m) \
TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m)
#define CHECK_TH_CUDA(x, m) \
STD_TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor")
TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x, m) \
STD_TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous")
TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous")
#define CHECK_INPUT(x, st, m) \
CHECK_TH_CUDA(x, m); \
CHECK_CONTIGUOUS(x, m); \
CHECK_TYPE(x, st, m)
constexpr auto FLOAT4_E2M1X2 = torch::headeronly::ScalarType::Byte;
constexpr auto SF_DTYPE = torch::headeronly::ScalarType::Float8_e4m3fn;
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
void cutlass_scaled_fp4_mm_sm100a(torch::stable::Tensor& D,
torch::stable::Tensor const& A,
torch::stable::Tensor const& B,
torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha) {
void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
torch::Tensor const& B,
torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha) {
CHECK_INPUT(A, FLOAT4_E2M1X2, "a");
CHECK_INPUT(B, FLOAT4_E2M1X2, "b");
CHECK_INPUT(A_sf, SF_DTYPE, "scale_a");
CHECK_INPUT(B_sf, SF_DTYPE, "scale_b");
CHECK_INPUT(alpha, torch::headeronly::ScalarType::Float, "alpha");
CHECK_INPUT(alpha, at::ScalarType::Float, "alpha");
STD_TORCH_CHECK(A.dim() == 2, "a must be a matrix");
STD_TORCH_CHECK(B.dim() == 2, "b must be a matrix");
STD_TORCH_CHECK(A.size(1) == B.size(1),
"a and b shapes cannot be multiplied (", A.size(0), "x",
A.size(1), " and ", B.size(0), "x", B.size(1), ")");
TORCH_CHECK(A.dim() == 2, "a must be a matrix");
TORCH_CHECK(B.dim() == 2, "b must be a matrix");
TORCH_CHECK(A.sizes()[1] == B.sizes()[1],
"a and b shapes cannot be multiplied (", A.sizes()[0], "x",
A.sizes()[1], " and ", B.sizes()[0], "x", B.sizes()[1], ")");
auto const m = A.size(0);
auto const n = B.size(0);
auto const k = A.size(1) * 2;
auto const m = A.sizes()[0];
auto const n = B.sizes()[0];
auto const k = A.sizes()[1] * 2;
constexpr int alignment = 32;
STD_TORCH_CHECK(k % alignment == 0, "Expected k to be divisible by ",
alignment, ", but got a shape: (", A.size(0), "x", A.size(1),
"), k: ", k, ".");
STD_TORCH_CHECK(n % alignment == 0, "Expected n to be divisible by ",
alignment, ", but got b shape: (", B.size(0), "x", B.size(1),
").");
TORCH_CHECK(k % alignment == 0, "Expected k to be divisible by ", alignment,
", but got a shape: (", A.sizes()[0], "x", A.sizes()[1],
"), k: ", k, ".");
TORCH_CHECK(n % alignment == 0, "Expected n to be divisible by ", alignment,
", but got b shape: (", B.sizes()[0], "x", B.sizes()[1], ").");
auto round_up = [](int x, int y) { return (x + y - 1) / y * y; };
int rounded_m = round_up(m, 128);
@@ -291,34 +285,33 @@ void cutlass_scaled_fp4_mm_sm100a(torch::stable::Tensor& D,
// integer.
int rounded_k = round_up(k / 16, 4);
STD_TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix");
STD_TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix");
STD_TORCH_CHECK(A_sf.size(1) == B_sf.size(1),
"scale_a and scale_b shapes cannot be multiplied (",
A_sf.size(0), "x", A_sf.size(1), " and ", B_sf.size(0), "x",
B_sf.size(1), ")");
STD_TORCH_CHECK(A_sf.size(0) == rounded_m && A_sf.size(1) == rounded_k,
"scale_a must be padded and swizzled to a shape (", rounded_m,
"x", rounded_k, "), but got a shape (", A_sf.size(0), "x",
A_sf.size(1), ")");
STD_TORCH_CHECK(B_sf.size(0) == rounded_n && B_sf.size(1) == rounded_k,
"scale_b must be padded and swizzled to a shape (", rounded_n,
"x", rounded_k, "), but got a shape (", B_sf.size(0), "x",
B_sf.size(1), ")");
TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix");
TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix");
TORCH_CHECK(A_sf.sizes()[1] == B_sf.sizes()[1],
"scale_a and scale_b shapes cannot be multiplied (",
A_sf.sizes()[0], "x", A_sf.sizes()[1], " and ", B_sf.sizes()[0],
"x", B_sf.sizes()[1], ")");
TORCH_CHECK(A_sf.sizes()[0] == rounded_m && A_sf.sizes()[1] == rounded_k,
"scale_a must be padded and swizzled to a shape (", rounded_m,
"x", rounded_k, "), but got a shape (", A_sf.sizes()[0], "x",
A_sf.sizes()[1], ")");
TORCH_CHECK(B_sf.sizes()[0] == rounded_n && B_sf.sizes()[1] == rounded_k,
"scale_b must be padded and swizzled to a shape (", rounded_n,
"x", rounded_k, "), but got a shape (", B_sf.sizes()[0], "x",
B_sf.sizes()[1], ")");
auto out_dtype = D.scalar_type();
const torch::stable::accelerator::DeviceGuard device_guard(
A.get_device_index());
const cudaStream_t stream = get_current_cuda_stream(A.get_device_index());
auto out_dtype = D.dtype();
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device());
if (out_dtype == torch::headeronly::ScalarType::Half) {
if (out_dtype == at::ScalarType::Half) {
cutlass_fp4_gemm_dispatch<cutlass::half_t>(D, A, B, A_sf, B_sf, alpha, m, n,
k, stream);
} else if (out_dtype == torch::headeronly::ScalarType::BFloat16) {
} else if (out_dtype == at::ScalarType::BFloat16) {
cutlass_fp4_gemm_dispatch<cutlass::bfloat16_t>(D, A, B, A_sf, B_sf, alpha,
m, n, k, stream);
} else {
STD_TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm (",
out_dtype, ")");
TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm (", out_dtype,
")");
}
}

View File

@@ -14,9 +14,10 @@
* limitations under the License.
*/
#include <torch/csrc/stable/tensor.h>
#include <torch/all.h>
#include "libtorch_stable/torch_utils.h"
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "cutlass_extensions/common.hpp"
@@ -33,20 +34,19 @@
using namespace cute;
#define CHECK_TYPE(x, st, m) \
STD_TORCH_CHECK(x.scalar_type() == st, \
": Inconsistency of torch::stable::Tensor type:", m)
#define CHECK_TYPE(x, st, m) \
TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m)
#define CHECK_TH_CUDA(x, m) \
STD_TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor")
TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x, m) \
STD_TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous")
TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous")
#define CHECK_INPUT(x, st, m) \
CHECK_TH_CUDA(x, m); \
CHECK_CONTIGUOUS(x, m); \
CHECK_TYPE(x, st, m)
constexpr auto FLOAT4_E2M1X2 = torch::headeronly::ScalarType::Byte;
constexpr auto SF_DTYPE = torch::headeronly::ScalarType::Float8_e4m3fn;
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
struct sm120_fp4_config_M256 {
using ClusterShape = Shape<_1, _1, _1>;
@@ -109,13 +109,12 @@ struct Fp4GemmSm120 {
};
template <typename Gemm>
typename Gemm::Arguments args_from_options(torch::stable::Tensor& D,
torch::stable::Tensor const& A,
torch::stable::Tensor const& B,
torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha,
int M, int N, int K) {
typename Gemm::Arguments args_from_options(at::Tensor& D, at::Tensor const& A,
at::Tensor const& B,
at::Tensor const& A_sf,
at::Tensor const& B_sf,
torch::Tensor const& alpha, int M,
int N, int K) {
using ElementA = typename Gemm::ElementA;
using ElementB = typename Gemm::ElementB;
using ElementD = typename Gemm::ElementD;
@@ -159,19 +158,18 @@ typename Gemm::Arguments args_from_options(torch::stable::Tensor& D,
}
template <typename Gemm>
void runGemm(torch::stable::Tensor& D, torch::stable::Tensor const& A,
torch::stable::Tensor const& B, torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha, int M, int N, int K,
void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
at::Tensor const& A_sf, at::Tensor const& B_sf,
torch::Tensor const& alpha, int M, int N, int K,
cudaStream_t stream) {
Gemm gemm;
auto arguments = args_from_options<Gemm>(D, A, B, A_sf, B_sf, alpha, M, N, K);
size_t workspace_size = Gemm::get_workspace_size(arguments);
auto workspace =
torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
std::nullopt, A.device());
auto const workspace_options =
torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
auto workspace = torch::empty(workspace_size, workspace_options);
CUTLASS_CHECK(gemm.can_implement(arguments));
@@ -180,13 +178,12 @@ void runGemm(torch::stable::Tensor& D, torch::stable::Tensor const& A,
CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream));
}
void cutlass_fp4_bf16_gemm_dispatch(torch::stable::Tensor& D,
torch::stable::Tensor const& A,
torch::stable::Tensor const& B,
torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha, int m,
int n, int k, cudaStream_t stream) {
void cutlass_fp4_bf16_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
torch::Tensor const& B,
torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha, int m, int n,
int k, cudaStream_t stream) {
uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
if (mp2 <= 256) {
runGemm<Fp4GemmSm120<sm120_fp4_config_M256, cutlass::bfloat16_t>::Gemm>(
@@ -197,13 +194,12 @@ void cutlass_fp4_bf16_gemm_dispatch(torch::stable::Tensor& D,
}
}
void cutlass_fp4_f16_gemm_dispatch(torch::stable::Tensor& D,
torch::stable::Tensor const& A,
torch::stable::Tensor const& B,
torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha, int m,
int n, int k, cudaStream_t stream) {
void cutlass_fp4_f16_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
torch::Tensor const& B,
torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha, int m, int n,
int k, cudaStream_t stream) {
uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
if (mp2 <= 256) {
runGemm<Fp4GemmSm120<sm120_fp4_config_M256, cutlass::half_t>::Gemm>(
@@ -214,12 +210,11 @@ void cutlass_fp4_f16_gemm_dispatch(torch::stable::Tensor& D,
}
}
void cutlass_scaled_fp4_mm_sm120a(torch::stable::Tensor& D,
torch::stable::Tensor const& A,
torch::stable::Tensor const& B,
torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha) {
void cutlass_scaled_fp4_mm_sm120a(torch::Tensor& D, torch::Tensor const& A,
torch::Tensor const& B,
torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha) {
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
CHECK_INPUT(A, FLOAT4_E2M1X2, "a");
CHECK_INPUT(B, FLOAT4_E2M1X2, "b");
@@ -227,25 +222,24 @@ void cutlass_scaled_fp4_mm_sm120a(torch::stable::Tensor& D,
CHECK_INPUT(A_sf, SF_DTYPE, "scale_a");
CHECK_INPUT(B_sf, SF_DTYPE, "scale_b");
CHECK_INPUT(alpha, torch::headeronly::ScalarType::Float, "alpha");
CHECK_INPUT(alpha, at::ScalarType::Float, "alpha");
STD_TORCH_CHECK(A.dim() == 2, "a must be a matrix");
STD_TORCH_CHECK(B.dim() == 2, "b must be a matrix");
STD_TORCH_CHECK(A.size(1) == B.size(1),
"a and b shapes cannot be multiplied (", A.size(0), "x",
A.size(1), " and ", B.size(0), "x", B.size(1), ")");
TORCH_CHECK(A.dim() == 2, "a must be a matrix");
TORCH_CHECK(B.dim() == 2, "b must be a matrix");
TORCH_CHECK(A.sizes()[1] == B.sizes()[1],
"a and b shapes cannot be multiplied (", A.sizes()[0], "x",
A.sizes()[1], " and ", B.sizes()[0], "x", B.sizes()[1], ")");
auto const m = A.size(0);
auto const n = B.size(0);
auto const k = A.size(1) * 2;
auto const m = A.sizes()[0];
auto const n = B.sizes()[0];
auto const k = A.sizes()[1] * 2;
constexpr int alignment = 32;
STD_TORCH_CHECK(k % alignment == 0, "Expected k to be divisible by ",
alignment, ", but got a shape: (", A.size(0), "x", A.size(1),
"), k: ", k, ".");
STD_TORCH_CHECK(n % alignment == 0, "Expected n to be divisible by ",
alignment, ", but got b shape: (", B.size(0), "x", B.size(1),
").");
TORCH_CHECK(k % alignment == 0, "Expected k to be divisible by ", alignment,
", but got a shape: (", A.sizes()[0], "x", A.sizes()[1],
"), k: ", k, ".");
TORCH_CHECK(n % alignment == 0, "Expected n to be divisible by ", alignment,
", but got b shape: (", B.sizes()[0], "x", B.sizes()[1], ").");
auto round_up = [](int x, int y) { return (x + y - 1) / y * y; };
int rounded_m = round_up(m, 128);
@@ -254,39 +248,38 @@ void cutlass_scaled_fp4_mm_sm120a(torch::stable::Tensor& D,
// integer.
int rounded_k = round_up(k / 16, 4);
STD_TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix");
STD_TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix");
STD_TORCH_CHECK(A_sf.size(1) == B_sf.size(1),
"scale_a and scale_b shapes cannot be multiplied (",
A_sf.size(0), "x", A_sf.size(1), " and ", B_sf.size(0), "x",
B_sf.size(1), ")");
STD_TORCH_CHECK(A_sf.size(0) == rounded_m && A_sf.size(1) == rounded_k,
"scale_a must be padded and swizzled to a shape (", rounded_m,
"x", rounded_k, "), but got a shape (", A_sf.size(0), "x",
A_sf.size(1), ")");
STD_TORCH_CHECK(B_sf.size(0) == rounded_n && B_sf.size(1) == rounded_k,
"scale_b must be padded and swizzled to a shape (", rounded_n,
"x", rounded_k, "), but got a shape (", B_sf.size(0), "x",
B_sf.size(1), ")");
TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix");
TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix");
TORCH_CHECK(A_sf.sizes()[1] == B_sf.sizes()[1],
"scale_a and scale_b shapes cannot be multiplied (",
A_sf.sizes()[0], "x", A_sf.sizes()[1], " and ", B_sf.sizes()[0],
"x", B_sf.sizes()[1], ")");
TORCH_CHECK(A_sf.sizes()[0] == rounded_m && A_sf.sizes()[1] == rounded_k,
"scale_a must be padded and swizzled to a shape (", rounded_m,
"x", rounded_k, "), but got a shape (", A_sf.sizes()[0], "x",
A_sf.sizes()[1], ")");
TORCH_CHECK(B_sf.sizes()[0] == rounded_n && B_sf.sizes()[1] == rounded_k,
"scale_b must be padded and swizzled to a shape (", rounded_n,
"x", rounded_k, "), but got a shape (", B_sf.sizes()[0], "x",
B_sf.sizes()[1], ")");
auto out_dtype = D.scalar_type();
const torch::stable::accelerator::DeviceGuard device_guard(
A.get_device_index());
const cudaStream_t stream = get_current_cuda_stream(A.get_device_index());
auto out_dtype = D.dtype();
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device());
if (out_dtype == torch::headeronly::ScalarType::BFloat16) {
if (out_dtype == at::ScalarType::BFloat16) {
return cutlass_fp4_bf16_gemm_dispatch(D, A, B, A_sf, B_sf, alpha, m, n, k,
stream);
} else if (out_dtype == torch::headeronly::ScalarType::Half) {
} else if (out_dtype == at::ScalarType::Half) {
return cutlass_fp4_f16_gemm_dispatch(D, A, B, A_sf, B_sf, alpha, m, n, k,
stream);
} else {
STD_TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm sm120 (",
out_dtype, ")");
TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm sm120 (",
out_dtype, ")");
}
#else
STD_TORCH_CHECK(false,
"Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
"a CUTLASS 3.8 source directory to enable support.");
TORCH_CHECK(false,
"Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
"a CUTLASS 3.8 source directory to enable support.");
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
}
}

View File

@@ -20,7 +20,7 @@
#include <cuda_fp8.h>
#include <utility>
#include "cuda_vec_utils.cuh"
#include "../../cuda_vec_utils.cuh"
#if defined(NVFP4_ENABLE_ELTS16) && defined(CUDA_VERSION) && \
CUDA_VERSION >= 12090

View File

@@ -1,169 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "../../dispatch_utils.h"
#include "quant_conversions.cuh"
#include "../w8a8/fp8/common.cuh"
namespace vllm {
// Logic: one thread block per (token, group) pair
template <typename scalar_t, typename scalar_out_t, bool is_scale_transposed,
int32_t group_size>
__global__ void silu_and_mul_per_block_quant_kernel(
scalar_out_t* __restrict__ out, // Output: [num_tokens, hidden_size] in
// FP8/INT8
float* __restrict__ scales, // Output: [num_tokens, hidden_size /
// group_size] or [hidden_size / group_size,
// num_tokens]
scalar_t const* __restrict__ input, // Input: [num_tokens, hidden_size * 2]
float const* scale_ub, // Optional scale upper bound
int32_t const hidden_size // Output hidden size (input is 2x this)
) {
static_assert((group_size & (group_size - 1)) == 0,
"group_size must be a power of 2 for correct reduction");
// Grid: (num_tokens, num_groups)
int const token_idx = blockIdx.x;
int const group_idx = blockIdx.y;
int const tid = threadIdx.x; // tid in [0, group_size)
int const num_tokens = gridDim.x;
// Input layout: [gate || up] concatenated along last dimension
int const input_stride = hidden_size * 2;
int const group_start = group_idx * group_size;
// Pointers to this token's data
scalar_t const* token_input_gate =
input + token_idx * input_stride + group_start;
scalar_t const* token_input_up = token_input_gate + hidden_size;
scalar_out_t* token_output = out + token_idx * hidden_size + group_start;
// Scale pointer for this group
int const num_groups = gridDim.y;
float* group_scale_ptr = is_scale_transposed
? scales + group_idx * num_tokens + token_idx
: scales + token_idx * num_groups + group_idx;
// Shared memory for reduction (compile-time sized)
__shared__ float shared_max[group_size];
// Step 1: Each thread loads one element, computes SiLU, stores in register
float gate = static_cast<float>(token_input_gate[tid]);
float up = static_cast<float>(token_input_up[tid]);
// Compute SiLU(gate) * up
float sigmoid_gate = 1.0f / (1.0f + expf(-gate));
float silu_gate = gate * sigmoid_gate;
float result = silu_gate * up; // Keep in register
// Step 2: Reduce to find group max
shared_max[tid] = fabsf(result);
__syncthreads();
// Power-of-2 reduction (group_size guaranteed to be power of 2)
#pragma unroll
for (int stride = group_size / 2; stride > 0; stride >>= 1) {
if (tid < stride) {
shared_max[tid] = fmaxf(shared_max[tid], shared_max[tid + stride]);
}
__syncthreads();
}
// Step 3: Compute scale (thread 0), broadcast via shared memory
if (tid == 0) {
float group_max = shared_max[0];
float const quant_range = quant_type_max_v<scalar_out_t>;
float group_scale = group_max / quant_range;
// Apply scale upper bound if provided
if (scale_ub != nullptr) {
group_scale = fminf(group_scale, *scale_ub);
}
// Use minimum safe scaling factor
group_scale = fmaxf(group_scale, min_scaling_factor<scalar_out_t>::val());
// Store scale to global memory
*group_scale_ptr = group_scale;
// Reuse shared_max[0] to broadcast scale
shared_max[0] = group_scale;
}
__syncthreads();
float group_scale = shared_max[0];
// Step 4: Quantize and write output
token_output[tid] =
vllm::ScaledQuant<scalar_out_t, false>::quant_fn(result, group_scale);
}
} // namespace vllm
void silu_and_mul_per_block_quant(torch::Tensor& out,
torch::Tensor const& input,
torch::Tensor& scales, int64_t group_size,
std::optional<torch::Tensor> scale_ub,
bool is_scale_transposed) {
static c10::ScalarType kFp8Type = is_fp8_ocp()
? c10::ScalarType::Float8_e4m3fn
: c10::ScalarType::Float8_e4m3fnuz;
TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8);
TORCH_CHECK(out.is_contiguous() && input.is_contiguous());
TORCH_CHECK(
input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16,
"Input must be FP16 or BF16");
TORCH_CHECK(scales.dtype() == torch::kFloat32, "Scales must be FP32");
TORCH_CHECK(group_size == 128 || group_size == 64,
"Unsupported group size: ", group_size);
if (scale_ub.has_value()) {
TORCH_CHECK(out.dtype() == kFp8Type);
}
int32_t hidden_size = out.size(-1);
auto num_tokens = input.size(0);
int32_t num_groups = hidden_size / group_size;
TORCH_CHECK(input.size(-1) == hidden_size * 2,
"input last dim must be 2x output hidden_size");
TORCH_CHECK(hidden_size % group_size == 0,
"hidden_size must be divisible by group_size");
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 grid(num_tokens, num_groups);
dim3 block(group_size);
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "silu_and_mul_per_block_quant", [&] {
using scalar_in_t = scalar_t;
VLLM_DISPATCH_QUANT_TYPES(
out.scalar_type(), "silu_and_mul_per_block_quant", [&] {
using scalar_out_t = scalar_t;
VLLM_DISPATCH_GROUP_SIZE(group_size, gs, [&] {
VLLM_DISPATCH_BOOL(is_scale_transposed, transpose_scale, [&] {
vllm::silu_and_mul_per_block_quant_kernel<
scalar_in_t, scalar_out_t, transpose_scale, gs>
<<<grid, block, 0, stream>>>(
out.data_ptr<scalar_out_t>(),
scales.data_ptr<float>(),
input.data_ptr<scalar_in_t>(),
scale_ub.has_value() ? scale_ub->data_ptr<float>()
: nullptr,
hidden_size);
});
});
});
});
}

View File

@@ -154,7 +154,6 @@ struct MacheteCollectiveMma {
struct DispatchPolicy {
constexpr static int Stages = PipelineStages;
using ClusterShape = ClusterShape_MNK;
using ArchTag = arch::Sm90;
using Schedule = KernelScheduleType;
};

View File

@@ -108,15 +108,6 @@ QUANT_CONFIGS = [
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [2],
},
# MXFP8
{
"a_type": ["kBFloat16"],
"b_type": "kFE4M3fn",
"s_type": "kFE8M0fnu",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [2],
},
# AWQ-INT4 with INT8 activation
{
"a_type": ["kS8"],

View File

@@ -591,9 +591,6 @@ torch::Tensor marlin_gemm(
"When b_type = float4_e2m1f, b_scale scalar type must be",
"float8_e4m3fn (for NVFP4) or float8_e8m0fnu (for MXFP4).");
}
} else if (b_type_id == vllm::kFE4M3fn.id() &&
b_scales.scalar_type() == at::ScalarType::Float8_e8m0fnu) {
s_type_id = vllm::kFE8M0fnu.id();
}
vllm::ScalarType a_type = vllm::ScalarType::from_id(a_type_id);

View File

@@ -327,9 +327,6 @@ __global__ void Marlin(
if constexpr (b_type == vllm::kFE2M1f) {
static_assert(s_type == vllm::kFE4M3fn && group_blocks == 1 ||
s_type == vllm::kFE8M0fnu && group_blocks == 2);
} else if constexpr (s_type == vllm::kFE8M0fnu) {
// MXFP8: FP8 weights with e8m0 microscaling block scales
static_assert(b_type == vllm::kFE4M3fn && group_blocks == 2);
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
static_assert(s_type == vllm::kBFloat16);
} else if constexpr (std::is_same<scalar_t, half>::value) {
@@ -337,7 +334,6 @@ __global__ void Marlin(
}
constexpr bool is_a_8bit = a_type.size_bits() == 8;
constexpr bool is_8bit_scale = s_type.size_bits() == 8;
if constexpr (!is_a_8bit) {
static_assert(std::is_same<scalar_t, c_scalar_t>::value);
}
@@ -347,7 +343,7 @@ __global__ void Marlin(
b_type == vllm::kU4B8 || b_type == vllm::kU8B128;
// see comments of dequant.h for more details
constexpr bool dequant_skip_flop =
is_a_8bit || (b_type == vllm::kFE4M3fn && !(s_type == vllm::kFE8M0fnu)) ||
is_a_8bit || b_type == vllm::kFE4M3fn ||
b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn ||
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
has_zp && !is_zp_float && !(b_type == vllm::kU8);
@@ -559,8 +555,9 @@ __global__ void Marlin(
constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
// Scale sizes/strides without act_order
int s_gl_stride = prob_n / (is_8bit_scale ? 16 : 8);
constexpr int s_sh_stride = 16 * thread_n_blocks / (is_8bit_scale ? 16 : 8);
int s_gl_stride = prob_n / (b_type == vllm::kFE2M1f ? 16 : 8);
constexpr int s_sh_stride =
16 * thread_n_blocks / (b_type == vllm::kFE2M1f ? 16 : 8);
constexpr int s_tb_groups =
!has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
? thread_k_blocks / group_blocks
@@ -1000,7 +997,7 @@ __global__ void Marlin(
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
if constexpr (!is_8bit_scale) {
if constexpr (b_type_id != vllm::kFE2M1f.id()) {
reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
} else {
@@ -1009,7 +1006,7 @@ __global__ void Marlin(
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
}
} else if (group_blocks >= b_sh_wr_iters) {
if constexpr (!is_8bit_scale) {
if constexpr (b_type_id != vllm::kFE2M1f.id()) {
reinterpret_cast<int4*>(&frag_s[1])[0] =
reinterpret_cast<int4*>(&frag_s[0])[0];
} else {
@@ -1210,7 +1207,7 @@ __global__ void Marlin(
}
}
if constexpr (s_type == vllm::kFE4M3fn || s_type == vllm::kFE8M0fnu) {
if constexpr (b_type == vllm::kFE2M1f) {
int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];
int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];

View File

@@ -2,10 +2,9 @@
// clang-format will break include orders
// clang-format off
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/ops.h>
#include <torch/all.h>
#include "libtorch_stable/torch_utils.h"
#include <ATen/cuda/CUDAContext.h>
#include "cutlass/cutlass.h"
@@ -26,14 +25,14 @@
namespace vllm::c3x {
static inline cute::Shape<int, int, int, int> get_problem_shape(
torch::stable::Tensor const& a, torch::stable::Tensor const& b) {
torch::Tensor const& a, torch::Tensor const& b) {
int32_t m = a.size(0), n = b.size(1), k = a.size(1);
return {m, n, k, 1};
}
template <typename GemmKernel>
void cutlass_gemm_caller(
torch::stable::Device device, cute::Shape<int, int, int, int> prob_shape,
torch::Device device, cute::Shape<int, int, int, int> prob_shape,
typename GemmKernel::MainloopArguments mainloop_args,
typename GemmKernel::EpilogueArguments epilogue_args,
typename GemmKernel::TileSchedulerArguments scheduler = {}) {
@@ -51,20 +50,19 @@ void cutlass_gemm_caller(
CUTLASS_CHECK(gemm_op.can_implement(args));
size_t workspace_size = gemm_op.get_workspace_size(args);
auto workspace =
torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
std::nullopt, device);
auto const workspace_options =
torch::TensorOptions().dtype(torch::kUInt8).device(device);
auto workspace = torch::empty(workspace_size, workspace_options);
auto stream = get_current_cuda_stream(device.index());
auto stream = at::cuda::getCurrentCUDAStream(device.index());
cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
CUTLASS_CHECK(status);
}
template <typename Gemm, typename... EpilogueArgs>
void cutlass_gemm_caller(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... epilogue_params) {
using ElementAB = typename Gemm::ElementAB;
using ElementC = typename Gemm::ElementC;

View File

@@ -4,12 +4,13 @@
namespace vllm {
void cutlass_scaled_mm_azp_sm90_int8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias) {
void cutlass_scaled_mm_azp_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
std::optional<torch::Tensor> const& azp,
std::optional<torch::Tensor> const& bias) {
if (azp) {
return cutlass_scaled_mm_sm90_int8_epilogue<
c3x::ScaledEpilogueBiasAzpToken>(out, a, b, a_scales, b_scales, azp_adj,

View File

@@ -0,0 +1,23 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_blockwise_sm100_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace vllm {
void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
if (out.dtype() == torch::kBFloat16) {
cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::bfloat16_t>(
out, a, b, a_scales, b_scales);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::half_t>(
out, a, b, a_scales, b_scales);
}
}
} // namespace vllm

View File

@@ -1,7 +1,5 @@
#pragma once
#include <torch/headeronly/util/shim_utils.h>
#include "cuda_utils.h"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
@@ -132,10 +130,10 @@ struct cutlass_3x_gemm_fp8_blockwise {
};
template <typename Gemm>
void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales) {
void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
static constexpr bool swap_ab = Gemm::swap_ab;
using GemmKernel = typename Gemm::GemmKernel;
using StrideA = typename Gemm::GemmKernel::StrideA;
@@ -202,11 +200,11 @@ void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Te
}
template <typename OutType>
void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales) {
void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
int32_t m = a.size(0), n = b.size(1), k = a.size(1), sms;
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device());

View File

@@ -0,0 +1,23 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_blockwise_sm120_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace vllm {
void cutlass_scaled_mm_blockwise_sm120_fp8(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
if (out.dtype() == torch::kBFloat16) {
cutlass_gemm_blockwise_sm120_fp8_dispatch<cutlass::bfloat16_t>(
out, a, b, a_scales, b_scales);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
cutlass_gemm_blockwise_sm120_fp8_dispatch<cutlass::half_t>(
out, a, b, a_scales, b_scales);
}
}
} // namespace vllm

View File

@@ -1,7 +1,5 @@
#pragma once
#include <torch/headeronly/util/shim_utils.h>
#include "cuda_utils.h"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
@@ -140,10 +138,10 @@ struct sm120_blockwise_fp8_config_M64 {
};
template <typename Gemm>
void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales) {
void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
using GemmKernel = typename Gemm::GemmKernel;
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
@@ -198,11 +196,11 @@ void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Te
}
template <typename OutType>
void cutlass_gemm_blockwise_sm120_fp8_dispatch(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales) {
void cutlass_gemm_blockwise_sm120_fp8_dispatch(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
int M = a.size(0);
if (M <= 256) {
using Gemm = typename sm120_blockwise_fp8_config_M64<OutType>::Gemm;

View File

@@ -0,0 +1,24 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_blockwise_sm90_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace vllm {
void cutlass_scaled_mm_blockwise_sm90_fp8(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
if (out.dtype() == torch::kBFloat16) {
cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::bfloat16_t>(
out, a, b, a_scales, b_scales);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::half_t>(
out, a, b, a_scales, b_scales);
}
}
} // namespace vllm

View File

@@ -1,7 +1,5 @@
#pragma once
#include <torch/headeronly/util/shim_utils.h>
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
@@ -103,10 +101,10 @@ struct cutlass_3x_gemm_fp8_blockwise {
};
template <typename Gemm>
void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales) {
void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
using GemmKernel = typename Gemm::GemmKernel;
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
@@ -122,7 +120,7 @@ void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Te
int32_t m = a.size(0), n = b.size(1), k = a.size(1);
STD_TORCH_CHECK(m % 4 == 0, "m must be divisible by 4");
TORCH_CHECK(m % 4 == 0, "m must be divisible by 4");
StrideA a_stride;
StrideB b_stride;
@@ -163,11 +161,11 @@ void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Te
}
template <typename OutType>
void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales) {
void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
// TODO: better heuristics
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
OutType, 1, 128, 128, Shape<_128, _128, _128>,

View File

@@ -1,57 +1,52 @@
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/core/ScalarType.h>
#include <torch/all.h>
#include "cuda_utils.h"
#include "cutlass_extensions/common.hpp"
template <typename Fp8Func, typename Int8Func, typename BlockwiseFunc>
void dispatch_scaled_mm(torch::stable::Tensor& c,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias,
void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias,
Fp8Func fp8_func, Int8Func int8_func,
BlockwiseFunc blockwise_func) {
STD_TORCH_CHECK(a_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
int M = a.size(0), N = b.size(1), K = a.size(1);
if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
(b_scales.numel() == 1 || b_scales.numel() == b.size(1))) {
// Standard per-tensor/per-token/per-channel scaling
STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (a.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn) {
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (a.dtype() == torch::kFloat8_e4m3fn) {
fp8_func(c, a, b, a_scales, b_scales, bias);
} else {
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
TORCH_CHECK(a.dtype() == torch::kInt8);
if constexpr (!std::is_same_v<Int8Func, std::nullptr_t>) {
int8_func(c, a, b, a_scales, b_scales, bias);
} else {
int32_t version_num = get_sm_version_num();
STD_TORCH_CHECK(
TORCH_CHECK(
false, "Int8 not supported on SM", version_num,
". Use FP8 quantization instead, or run on older arch (SM < 100).");
}
}
} else {
STD_TORCH_CHECK(a_scales.dim() == 2, "a scale must be 2d tensor.");
STD_TORCH_CHECK(b_scales.dim() == 2, "b scale must be 2d tensor.");
TORCH_CHECK(a_scales.dim() == 2, "a scale must be 2d tensor.");
TORCH_CHECK(b_scales.dim() == 2, "b scale must be 2d tensor.");
int32_t version_num = get_sm_version_num();
if (version_num >= 90) {
STD_TORCH_CHECK(
TORCH_CHECK(
a.size(0) == a_scales.size(0) &&
cuda_utils::ceil_div(a.size(1), int64_t(128)) == a_scales.size(1),
"a_scale_group_shape must be [1, 128].");
STD_TORCH_CHECK(
TORCH_CHECK(
cuda_utils::ceil_div(b.size(0), int64_t(128)) == b_scales.size(0) &&
cuda_utils::ceil_div(b.size(1), int64_t(128)) == b_scales.size(1),
"b_scale_group_shape must be [128, 128].");
}
STD_TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
blockwise_func(c, a, b, a_scales, b_scales);
}
}

Some files were not shown because too many files have changed in this diff Show More