Compare commits: v0.19.0...v0.19.1rc0 (152 commits)
SHA1s:
5506435419, 311c981647, 21d7ecc5b0, 4729b90838, 8b141ed8c3, 2ad7c0335f, 201d2ea5bf, 103f0de565,
32e0c0bfa2, 4a06e1246e, 3bc2734dd0, 1f5ec2889c, ee3cf45739, 05e68e1f81, 771913e4a0, 71a9125c67,
66e86f1dbd, bb39382b2b, 7b743ba953, 188defbd0b, 08ed2b9688, ecd5443dbc, 58262dec6e, cb3935a8fc,
82a006beeb, a9b4f07ba2, d9408ffba3, 16a65e4173, c0817e4d39, dfe5e31689, 2ce3d0ce36, 4eefbf9609,
551b3fb39f, c6f722b93e, 9bd7231106, 73f48ce559, 3aab680e3e, 5a2d420c17, 5f96f9aff1, 694449050f,
6241521dd2, 1785dc5501, 54500546ac, de5e6c44c6, cb268e4e55, 6183cae1bd, c09ad767cd, c9a9db0e02,
cbe7d18096, db5d0719e1, dc0428ebb8, 148c2072ec, 2f5c3c1ec0, fa246d5231, 7cf56a59a2, 5e30e9b9a9,
582340f273, 992368522f, 58ee614221, f9f6a9097a, c75a313824, 4f6eed3bd4, 36d7f19897, 2d725b89c5,
ef53395e2c, eb47454987, 116f4be405, 7b01d97a22, 17b72fd1c8, c49497726b, cb0b443274, 40bb175027,
0fab52f0aa, 91e4521f9f, 31a719bcd3, 2e56975657, 36f1dc19ae, 3dc01ef352, cc671cb110, 856589ed9a,
517b769b58, d9b90a07ac, 598190aac3, b779eb3363, 077a9a8e37, 07edd551cc, 7c080dd3c5, 0dd25a44ea,
3896e021a0, b6e636c12c, f1ff50c86c, 757068dc65, 7337ff7f03, 5869f69c5f, 4dfad17ed1, e8057c00bc,
7430389669, 202f147cf2, ea7bfde6e4, d71a15041f, abdbb68386, 0c63739135, 719735d6c5, aae3e688f8,
7d65463528, 8278825b57, acf7292bf2, ce884756f0, d9d21eb8e3, f09daea261, 42318c840b, 1ac6694297,
6cc7abdc66, d53cb9cb8e, 44eef0ca1e, b9cdc85207, 3e802e8786, 350af48e14, e31915063d, 29e48707e8,
4ac227222f, bb51d5b40d, 93b3ec1585, e812bf70bd, bcc6f67447, 1fc69f59bb, d9c7db18da, 12701e8af2,
494636b29d, ab1a6a43fa, b5e608258e, 2c734ed0e0, 3b1dbaad4e, b4a2f3ac36, 8e6293e838, dbdd9ae067,
e8b055a5ac, 246dc7d864, 7c3f88b2a8, 6557f4937f, 677424c7ac, 1031c84c36, 7e76af14fa, 3683fe6c06,
cc06b4e86b, 03ac6ca895, a08b7733fd, 85c0950b1f, 57861ae48d, ac30a8311e, 63babd17f1, fec5aeca12
@@ -5,6 +5,7 @@ steps:
    depends_on: []
    device: amd_cpu
    no_plugin: true
    soft_fail: true
    commands:
      - >
        docker build
@@ -20,11 +21,3 @@ steps:
      - docker push "rocm/vllm-ci:${BUILDKITE_COMMIT}"
    env:
      DOCKER_BUILDKIT: "1"
    retry:
      automatic:
        - exit_status: -1 # Agent was lost
          limit: 1
        - exit_status: -10 # Agent was lost
          limit: 1
        - exit_status: 1 # Machine occasionally fail
          limit: 1
@@ -13,12 +13,14 @@ steps:
      - tests/kernels/attention/test_cpu_attn.py
      - tests/kernels/moe/test_cpu_fused_moe.py
      - tests/kernels/test_onednn.py
      - tests/kernels/test_awq_int4_to_int8.py
    commands:
      - |
        bash .buildkite/scripts/hardware_ci/run-cpu-test.sh 20m "
          pytest -x -v -s tests/kernels/attention/test_cpu_attn.py
          pytest -x -v -s tests/kernels/moe/test_cpu_fused_moe.py
          pytest -x -v -s tests/kernels/test_onednn.py"
          pytest -x -v -s tests/kernels/test_onednn.py
          pytest -x -v -s tests/kernels/test_awq_int4_to_int8.py"

  - label: CPU-Compatibility Tests
    depends_on: []
@@ -36,6 +36,7 @@
        "model": "meta-llama/Llama-3.1-8B-Instruct",
        "backend": "vllm",
        "ignore-eos": "",
        "temperature": 0,
        "num_prompts": 200
      }
    },
@@ -127,4 +128,4 @@
      }
    }
  ]
}
}
@@ -22,6 +22,7 @@
        "hf_split": "test",
        "no_stream": "",
        "no_oversample": "",
        "temperature": 0,
        "num_prompts": 200
      }
    },
@@ -26,6 +26,7 @@
        "model": "meta-llama/Llama-3.1-8B-Instruct",
        "backend": "vllm",
        "ignore-eos": "",
        "temperature": 0,
        "num_prompts": 200
      }
    },
@@ -26,6 +26,7 @@
        "model": "meta-llama/Llama-3.1-8B-Instruct",
        "backend": "vllm",
        "ignore-eos": "",
        "temperature": 0,
        "num_prompts": 200
      }
    },
@@ -21,6 +21,7 @@
        "backend": "vllm",
        "dataset_name": "sharegpt",
        "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
        "temperature": 0,
        "num_prompts": 200
      }
    },
@@ -47,6 +48,7 @@
        "backend": "vllm",
        "dataset_name": "sharegpt",
        "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
        "temperature": 0,
        "num_prompts": 200
      }
    },
@@ -73,6 +75,7 @@
        "backend": "vllm",
        "dataset_name": "sharegpt",
        "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
        "temperature": 0,
        "num_prompts": 200
      }
    },
@@ -100,6 +103,7 @@
        "backend": "vllm",
        "dataset_name": "sharegpt",
        "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
        "temperature": 0,
        "num_prompts": 200
      }
    },
@@ -127,6 +131,7 @@
        "backend": "vllm",
        "dataset_name": "sharegpt",
        "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
        "temperature": 0,
        "num_prompts": 200
      }
    },
@@ -151,6 +156,7 @@
        "backend": "vllm",
        "dataset_name": "sharegpt",
        "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
        "temperature": 0,
        "num_prompts": 200
      }
    }
@@ -13,6 +13,7 @@
        "backend": "vllm",
        "dataset_name": "sharegpt",
        "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
        "temperature": 0,
        "num_prompts": 200
      }
    },
@@ -30,6 +31,7 @@
        "backend": "vllm",
        "dataset_name": "sharegpt",
        "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
        "temperature": 0,
        "num_prompts": 200
      }
    },
@@ -47,6 +49,7 @@
        "backend": "vllm",
        "dataset_name": "sharegpt",
        "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
        "temperature": 0,
        "num_prompts": 200
      }
    },
@@ -67,6 +70,7 @@
        "backend": "vllm",
        "dataset_name": "sharegpt",
        "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
        "temperature": 0,
        "num_prompts": 200
      }
    }
@@ -1,9 +1,10 @@
#!/bin/bash
set -euox pipefail
export VLLM_CPU_CI_ENV=0
export VLLM_CPU_KVCACHE_SPACE=1 # avoid OOM

echo "--- PP+TP"
vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -pp=2 &
vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -pp=2 --max-model-len=4096 &
server_pid=$!
timeout 600 bash -c "until curl localhost:8000/v1/models > /dev/null 2>&1; do sleep 1; done" || exit 1
vllm bench serve \
@@ -23,7 +24,7 @@ if [ "$failed_req" -ne 0 ]; then
fi

echo "--- DP+TP"
vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -dp=2 &
vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -dp=2 --max-model-len=4096 &
server_pid=$!
timeout 600 bash -c "until curl localhost:8000/v1/models > /dev/null 2>&1; do sleep 1; done" || exit 1
vllm bench serve \
@@ -239,13 +239,29 @@ fi
# --- Docker housekeeping ---
cleanup_docker

aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin "$REGISTRY"
aws ecr get-login-password --region us-east-1 | docker login --username AWS --password-stdin 936637512419.dkr.ecr.us-east-1.amazonaws.com

# --- Build or pull test image ---
if [[ -n "${IMAGE_TAG_XPU:-}" ]]; then
  echo "Using prebuilt XPU image: ${IMAGE_TAG_XPU}"
  docker pull "${IMAGE_TAG_XPU}"
IMAGE="${IMAGE_TAG_XPU:-${image_name}}"

echo "Using image: ${IMAGE}"

if docker image inspect "${IMAGE}" >/dev/null 2>&1; then
  echo "Image already exists locally, skipping pull"
else
  echo "Using prebuilt XPU image: ${image_name}"
  docker pull "${image_name}"
  echo "Image not found locally, waiting for lock..."

  flock /tmp/docker-pull.lock bash -c "
    if docker image inspect '${IMAGE}' >/dev/null 2>&1; then
      echo 'Image already pulled by another runner'
    else
      echo 'Pulling image...'
      timeout 900 docker pull '${IMAGE}'
    fi
  "

  echo "Pull step completed"
fi

remove_docker_container() {
@@ -42,6 +42,7 @@ docker run \
    python3 examples/basic/offline_inference/generate.py --model superjob/Qwen3-4B-Instruct-2507-GPTQ-Int4 --block-size 64 --enforce-eager --max-model-len 8192
    python3 examples/basic/offline_inference/generate.py --model ibm-research/PowerMoE-3b --block-size 64 --enforce-eager -tp 2
    python3 examples/basic/offline_inference/generate.py --model ibm-research/PowerMoE-3b --block-size 64 --enforce-eager -tp 2 --enable-expert-parallel
    python3 examples/basic/offline_inference/generate.py --model OPEA/Qwen2.5-0.5B-Instruct-int4-sym-inc --block-size 64 --enforce-eager --max-model-len 8192
    cd tests
    pytest -v -s v1/core --ignore=v1/core/test_reset_prefix_cache_e2e.py --ignore=v1/core/test_scheduler_e2e.py
    pytest -v -s v1/engine
@@ -790,7 +790,7 @@ steps:
      - tests/kernels/helion/
      - vllm/platforms/rocm.py
    commands:
      - pip install helion
      - pip install helion==0.3.3
      - pytest -v -s kernels/helion/
@@ -2,14 +2,6 @@ group: Benchmarks
depends_on:
  - image-build
steps:
  - label: Benchmarks
    timeout_in_minutes: 20
    working_dir: "/vllm-workspace/.buildkite"
    source_file_dependencies:
      - benchmarks/
    commands:
      - bash scripts/run-benchmarks.sh

  - label: Benchmarks CLI Test
    timeout_in_minutes: 20
    source_file_dependencies:
@@ -72,6 +72,7 @@ steps:
      - vllm/v1/attention/backends/flashinfer.py
      - vllm/compilation/ # TODO(luka) limit to vllm/compilation/passes
      - tests/compile/passes/test_fusion_attn.py
      - tests/compile/passes/test_mla_attn_quant_fusion.py
      - tests/compile/passes/test_silu_mul_quant_fusion.py
      - tests/compile/passes/distributed/test_fusion_all_reduce.py
      - tests/compile/fullgraph/test_full_graph.py
@@ -79,6 +80,7 @@
      # b200 runners are limited, so we limit the tests to the minimum set only supported on Blackwell
      - nvidia-smi
      - pytest -v -s tests/compile/passes/test_fusion_attn.py -k FLASHINFER
      - pytest -v -s tests/compile/passes/test_mla_attn_quant_fusion.py
      - pytest -v -s tests/compile/passes/test_silu_mul_quant_fusion.py
      # this runner has 2 GPUs available even though num_devices=2 is not set
      - pytest -v -s tests/compile/passes/distributed/test_fusion_all_reduce.py
@@ -224,6 +224,20 @@ steps:
    commands:
      - ./.buildkite/scripts/run-multi-node-test.sh /vllm-workspace/tests 2 2 $IMAGE_TAG "VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed' && NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed' && python3 ../examples/offline_inference/data_parallel.py -dp=2 -tp=1 --dp-num-nodes=2 --dp-node-rank=0 --dp-master-addr=192.168.10.10 --dp-master-port=12345 --enforce-eager --trust-remote-code && VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py && VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py" "VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed' && NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed' && python3 ../examples/offline_inference/data_parallel.py -dp=2 -tp=1 --dp-num-nodes=2 --dp-node-rank=1 --dp-master-addr=192.168.10.10 --dp-master-port=12345 --enforce-eager --trust-remote-code"

  - label: MessageQueue TCP Multi-Node (2 GPUs)
    timeout_in_minutes: 10
    working_dir: "/vllm-workspace/tests"
    num_devices: 1
    num_nodes: 2
    no_plugin: true
    optional: true
    source_file_dependencies:
      - vllm/distributed/device_communicators/shm_broadcast.py
      - vllm/distributed/parallel_state.py
      - tests/distributed/test_mq_tcp_multinode.py
    commands:
      - ./.buildkite/scripts/run-multi-node-test.sh /vllm-workspace/tests 2 1 $IMAGE_TAG "torchrun --nnodes 2 --nproc-per-node=1 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_mq_tcp_multinode.py" "torchrun --nnodes 2 --nproc-per-node=1 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_mq_tcp_multinode.py"

  - label: Distributed NixlConnector PD accuracy (4 GPUs)
    timeout_in_minutes: 30
    working_dir: "/vllm-workspace/tests"
@@ -294,3 +308,23 @@ steps:
    commands:
      - pytest -v -s distributed/test_pp_cudagraph.py
      - pytest -v -s distributed/test_pipeline_parallel.py

  - label: RayExecutorV2 (4 GPUs)
    timeout_in_minutes: 60
    working_dir: "/vllm-workspace/tests"
    num_devices: 4
    source_file_dependencies:
      - vllm/v1/executor/ray_executor_v2.py
      - vllm/v1/executor/abstract.py
      - vllm/v1/executor/multiproc_executor.py
      - tests/distributed/test_ray_v2_executor.py
      - tests/distributed/test_ray_v2_executor_e2e.py
      - tests/distributed/test_pipeline_parallel.py
      - tests/basic_correctness/test_basic_correctness.py
    commands:
      - export VLLM_USE_RAY_V2_EXECUTOR_BACKEND=1
      - export NCCL_CUMEM_HOST_ENABLE=0
      - pytest -v -s distributed/test_ray_v2_executor.py
      - pytest -v -s distributed/test_ray_v2_executor_e2e.py
      - pytest -v -s distributed/test_pipeline_parallel.py -k "ray"
      - TARGET_TEST_SUITE=L4 pytest -v -s basic_correctness/test_basic_correctness.py -k "ray"
@@ -13,8 +13,8 @@ steps:
      - pytest -v -s distributed/test_eplb_algo.py
      - pytest -v -s distributed/test_eplb_utils.py

  - label: EPLB Execution
    timeout_in_minutes: 20
  - label: EPLB Execution # 17min
    timeout_in_minutes: 27
    working_dir: "/vllm-workspace/tests"
    num_devices: 4
    source_file_dependencies:
@@ -2,6 +2,16 @@ group: Kernels
depends_on:
  - image-build
steps:
  - label: vLLM IR Tests
    timeout_in_minutes: 10
    working_dir: "/vllm-workspace/"
    source_file_dependencies:
      - vllm/ir
      - vllm/kernels
    commands:
      - pytest -v -s tests/ir
      - pytest -v -s tests/kernels/ir

  - label: Kernels Core Operation Test
    timeout_in_minutes: 75
    source_file_dependencies:
@@ -19,6 +29,7 @@ steps:
      - vllm/v1/attention
      # TODO: remove this dependency (https://github.com/vllm-project/vllm/issues/32267)
      - vllm/model_executor/layers/attention
      - vllm/utils/flashinfer.py
      - tests/kernels/attention
    commands:
      - pytest -v -s kernels/attention --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
@@ -129,7 +140,7 @@ steps:
      - vllm/utils/import_utils.py
      - tests/kernels/helion/
    commands:
      - pip install helion
      - pip install helion==0.3.3
      - pytest -v -s kernels/helion/
@@ -18,5 +18,6 @@ steps:
      # Avoid importing model tests that cause CUDA reinitialization error
      - pytest models/test_transformers.py -v -s -m 'distributed(num_gpus=2)'
      - pytest models/language -v -s -m 'distributed(num_gpus=2)'
      - pytest models/multimodal -v -s -m 'distributed(num_gpus=2)' --ignore models/multimodal/generation/test_whisper.py
      - pytest models/multimodal/generation/test_phi4siglip.py -v -s -m 'distributed(num_gpus=2)'
      - pytest models/multimodal -v -s -m 'distributed(num_gpus=2)' --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/generation/test_phi4siglip.py
      - VLLM_WORKER_MULTIPROC_METHOD=spawn pytest models/multimodal/generation/test_whisper.py -v -s -m 'distributed(num_gpus=2)'
.github/CODEOWNERS (vendored, 30 changed lines)
@@ -2,16 +2,20 @@
# for more info about CODEOWNERS file

# This lists cover the "core" components of vLLM that require careful review
/vllm/compilation @zou3519 @youkaichao @ProExpertProg @BoyuanFeng
/vllm/compilation @zou3519 @youkaichao @ProExpertProg @BoyuanFeng @vadiklyutiy
/vllm/distributed/kv_transfer @NickLucche @ApostaC @orozery
/vllm/lora @jeejeelee
/vllm/model_executor/layers/attention @LucasWilkinson @MatthewBonanni
/vllm/model_executor/layers/fused_moe @mgoin @pavanimajety
/vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth @yewentao256 @pavanimajety
/vllm/model_executor/layers/mamba @tdoublep
/vllm/model_executor/layers/mamba/gdn_linear_attn.py @tdoublep @ZJY0516
/vllm/model_executor/layers/mamba @tdoublep @tomeras91
/vllm/model_executor/layers/mamba/gdn_linear_attn.py @tdoublep @ZJY0516 @vadiklyutiy
/vllm/model_executor/layers/rotary_embedding.py @vadiklyutiy
/vllm/model_executor/model_loader @22quinn
/vllm/model_executor/layers/batch_invariant.py @yewentao256
/vllm/ir @ProExpertProg
/vllm/kernels/ @ProExpertProg @tjtanaa
/vllm/kernels/helion @ProExpertProg @zou3519
/vllm/multimodal @DarkLight1337 @ywang96 @NickLucche @tjtanaa
/vllm/vllm_flash_attn @LucasWilkinson @MatthewBonanni
CMakeLists.txt @tlrmchlsmth @LucasWilkinson
@@ -47,9 +51,9 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
/vllm/v1/attention @LucasWilkinson @MatthewBonanni
/vllm/v1/attention/backend.py @WoosukKwon @zhuohan123 @youkaichao @alexm-redhat @njhill
/vllm/v1/attention/backends/mla @pavanimajety
/vllm/v1/attention/backends/flashinfer.py @mgoin @pavanimajety
/vllm/v1/attention/backends/flashinfer.py @mgoin @pavanimajety @vadiklyutiy
/vllm/v1/attention/backends/triton_attn.py @tdoublep
/vllm/v1/attention/backends/gdn_attn.py @ZJY0516
/vllm/v1/attention/backends/gdn_attn.py @ZJY0516 @vadiklyutiy
/vllm/v1/core @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @alexm-redhat @heheda12345 @ApostaC @orozery
/vllm/v1/sample @22quinn @houseroad @njhill
/vllm/v1/spec_decode @benchislett @luccafong @MatthewBonanni
@@ -71,8 +75,9 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
/tests/distributed/test_pipeline_parallel.py @youkaichao
/tests/distributed/test_same_node.py @youkaichao
/tests/entrypoints @DarkLight1337 @robertgshaw2-redhat @aarnphm @NickLucche
/tests/evals @mgoin
/tests/evals @mgoin @vadiklyutiy
/tests/kernels @mgoin @tlrmchlsmth @WoosukKwon @yewentao256
/tests/kernels/ir @ProExpertProg @tjtanaa
/tests/models @DarkLight1337 @ywang96
/tests/multimodal @DarkLight1337 @ywang96 @NickLucche
/tests/quantization @mgoin @robertgshaw2-redhat @yewentao256 @pavanimajety
@@ -82,7 +87,7 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
/tests/v1/core @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @alexm-redhat @heheda12345 @ApostaC @orozery
/tests/weight_loading @mgoin @youkaichao @yewentao256
/tests/lora @jeejeelee
/tests/models/language/generation/test_hybrid.py @tdoublep
/tests/models/language/generation/test_hybrid.py @tdoublep @tomeras91
/tests/v1/kv_connector/nixl_integration @NickLucche
/tests/v1/kv_connector @ApostaC @orozery
/tests/v1/kv_offload @ApostaC @orozery
@@ -126,9 +131,14 @@ mkdocs.yaml @hmellor
/vllm/platforms/xpu.py @jikunshang
/docker/Dockerfile.xpu @jikunshang

# Nemotron-specific files
/vllm/model_executor/models/*nemotron* @tomeras91
/vllm/transformers_utils/configs/*nemotron* @tomeras91
/tests/**/*nemotron* @tomeras91

# Qwen-specific files
/vllm/attention/backends/dual_chunk_flash_attn.py @sighingnow
/vllm/model_executor/models/qwen* @sighingnow
/vllm/model_executor/models/qwen* @sighingnow @vadiklyutiy
/vllm/transformers_utils/configs/qwen* @sighingnow @vadiklyutiy

# MTP-specific files
/vllm/model_executor/models/deepseek_mtp.py @luccafong
@@ -144,7 +154,7 @@ mkdocs.yaml
# Kernels
/vllm/v1/attention/ops/chunked_prefill_paged_decode.py @tdoublep
/vllm/v1/attention/ops/triton_unified_attention.py @tdoublep
/vllm/model_executor/layers/fla @ZJY0516
/vllm/model_executor/layers/fla @ZJY0516 @vadiklyutiy

# ROCm related: specify owner with write access to notify AMD folks for careful code review
/vllm/**/*rocm* @tjtanaa
.github/workflows/pre-commit.yml (vendored, 7 changed lines)
@@ -28,6 +28,7 @@ jobs:
            });

            const hasReadyLabel = pr.labels.some(l => l.name === 'ready');
            const hasVerifiedLabel = pr.labels.some(l => l.name === 'verified');

            const { data: mergedPRs } = await github.rest.search.issuesAndPullRequests({
              q: `repo:${context.repo.owner}/${context.repo.repo} is:pr is:merged author:${pr.user.login}`,
@@ -35,10 +36,10 @@ jobs:
            });
            const mergedCount = mergedPRs.total_count;

            if (hasReadyLabel || mergedCount >= 4) {
              core.info(`Check passed: ready label=${hasReadyLabel}, 4+ merged PRs=${mergedCount >= 4}`);
            if (hasReadyLabel || hasVerifiedLabel || mergedCount >= 4) {
              core.info(`Check passed: verified label=${hasVerifiedLabel}, ready label=${hasReadyLabel}, 4+ merged PRs=${mergedCount >= 4}`);
            } else {
              core.setFailed(`PR must have the 'ready' label or the author must have at least 4 merged PRs (found ${mergedCount}).`);
              core.setFailed(`PR must have the 'verified' or 'ready' (which also triggers tests) label or the author must have at least 4 merged PRs (found ${mergedCount}).`);
            }

  pre-commit:

@@ -39,7 +39,7 @@ repos:
    rev: 0.11.1
    hooks:
      - id: pip-compile
        args: [requirements/test.in, -o, requirements/test.txt, --index-strategy, unsafe-best-match, --torch-backend, cu129, --python-platform, x86_64-manylinux_2_28, --python-version, "3.12"]
        args: [requirements/test.in, -c, requirements/common.txt, -o, requirements/test.txt, --index-strategy, unsafe-best-match, --torch-backend, cu129, --python-platform, x86_64-manylinux_2_28, --python-version, "3.12"]
        files: ^requirements/test\.(in|txt)$
      - id: pip-compile
        alias: pip-compile-rocm
CMakeLists.txt (600 changed lines)
@@ -309,7 +309,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
  SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")

  # Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building.
  set(CUTLASS_REVISION "v4.2.1")
  set(CUTLASS_REVISION "v4.4.2")

  # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
  if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR})
@@ -340,10 +340,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")

  list(APPEND VLLM_EXT_SRC
    "csrc/quantization/awq/gemm_kernels.cu"
    "csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu"
    "csrc/quantization/fp4/nvfp4_quant_entry.cu"
    "csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
    "csrc/cutlass_extensions/common.cpp")
    "csrc/cutlass_extensions/common.cpp"
    "csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu")

  set_gencode_flags_for_srcs(
    SRCS "${VLLM_EXT_SRC}"
@@ -490,185 +488,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
" in CUDA target architectures")
|
||||
endif()
|
||||
|
||||
|
||||
set(SCALED_MM_3X_ARCHS)
|
||||
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
|
||||
# CUDA 12.0 or later
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm90.cu"
|
||||
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8.cu"
|
||||
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_int8.cu"
|
||||
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_azp_sm90_int8.cu"
|
||||
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM90=1")
|
||||
# Let scaled_mm_c2x know it doesn't need to build these arches
|
||||
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
|
||||
message(STATUS "Building scaled_mm_c3x_sm90 for archs: ${SCALED_MM_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS)
|
||||
message(STATUS "Not building scaled_mm_c3x_sm90 as CUDA Compiler version is "
|
||||
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
|
||||
"later if you intend on running FP8 quantized models on "
|
||||
"Hopper.")
|
||||
else()
|
||||
message(STATUS "Not building scaled_mm_c3x_sm90 as no compatible archs found "
|
||||
"in CUDA target architectures")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
||||
# The cutlass_scaled_mm kernels for Blackwell SM12x (c3x, i.e. CUTLASS 3.x) require
|
||||
# CUDA 12.8 or later
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0f" "${CUDA_ARCHS}")
|
||||
else()
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0a;12.1a" "${CUDA_ARCHS}")
|
||||
endif()
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm120.cu"
|
||||
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8.cu"
|
||||
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8.cu"
|
||||
)
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM120=1")
|
||||
# Let scaled_mm_c2x know it doesn't need to build these arches
|
||||
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
|
||||
message(STATUS "Building scaled_mm_c3x_sm120 for archs: ${SCALED_MM_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
message(STATUS "Not building scaled_mm_c3x_sm120 as CUDA Compiler version is "
|
||||
"not >= 12.8, we recommend upgrading to CUDA 12.8 or "
|
||||
"later if you intend on running FP8 quantized models on "
|
||||
"Blackwell.")
|
||||
else()
|
||||
message(STATUS "Not building scaled_mm_c3x_120 as no compatible archs found "
|
||||
"in CUDA target architectures")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
||||
# The cutlass_scaled_mm kernels for Blackwell SM100 (c3x, i.e. CUTLASS 3.x)
|
||||
# require CUDA 12.8 or later
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
|
||||
else()
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
|
||||
endif()
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm100.cu"
|
||||
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu"
|
||||
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8.cu"
|
||||
)
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM100=1")
|
||||
# Let scaled_mm_c2x know it doesn't need to build these arches
|
||||
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
|
||||
message(STATUS "Building scaled_mm_c3x_sm100 for archs: ${SCALED_MM_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
message(STATUS "Not building scaled_mm_c3x_sm100 as CUDA Compiler version is "
|
||||
"not >= 12.8, we recommend upgrading to CUDA 12.8 or "
|
||||
"later if you intend on running FP8 quantized models on "
|
||||
"Blackwell.")
|
||||
else()
|
||||
message(STATUS "Not building scaled_mm_c3x_100 as no compatible archs found "
|
||||
"in CUDA target architectures")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
#
|
||||
# For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x)
|
||||
# kernels for the remaining archs that are not already built for 3x.
|
||||
# (Build 8.9 for FP8)
|
||||
cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS
|
||||
"7.5;8.0;8.7;8.9+PTX" "${CUDA_ARCHS}")
|
||||
# subtract out the archs that are already built for 3x
|
||||
list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS})
|
||||
if (SCALED_MM_2X_ARCHS)
|
||||
set(SRCS "csrc/quantization/w8a8/cutlass/scaled_mm_c2x.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_2X_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_C2X=1")
|
||||
message(STATUS "Building scaled_mm_c2x for archs: ${SCALED_MM_2X_ARCHS}")
|
||||
else()
|
||||
if (SCALED_MM_3X_ARCHS)
|
||||
message(STATUS "Not building scaled_mm_c2x as all archs are already built"
|
||||
" for and covered by scaled_mm_c3x")
|
||||
else()
|
||||
message(STATUS "Not building scaled_mm_c2x as no compatible archs found "
|
||||
"in CUDA target architectures")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# The nvfp4_scaled_mm_sm120 kernels for Blackwell SM12x require
|
||||
# CUDA 12.8 or later
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||
cuda_archs_loose_intersection(FP4_ARCHS "12.0f" "${CUDA_ARCHS}")
|
||||
else()
|
||||
cuda_archs_loose_intersection(FP4_ARCHS "12.0a;12.1a" "${CUDA_ARCHS}")
|
||||
endif()
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/quantization/fp4/nvfp4_quant_kernels.cu"
|
||||
"csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
|
||||
"csrc/quantization/fp4/nvfp4_experts_quant.cu"
|
||||
"csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu"
|
||||
"csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${FP4_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM120=1")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM120=1")
|
||||
message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
|
||||
else()
|
||||
message(STATUS "Not building NVFP4 as no compatible archs were found.")
|
||||
# clear FP4_ARCHS
|
||||
set(FP4_ARCHS)
|
||||
endif()
|
||||
|
||||
# FP4 Archs and flags
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||
cuda_archs_loose_intersection(FP4_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
|
||||
else()
|
||||
cuda_archs_loose_intersection(FP4_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
|
||||
endif()
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/quantization/fp4/nvfp4_quant_kernels.cu"
|
||||
"csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
|
||||
"csrc/quantization/fp4/nvfp4_experts_quant.cu"
|
||||
"csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu"
|
||||
"csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${FP4_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM100=1")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
|
||||
message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
|
||||
else()
|
||||
message(STATUS "Not building NVFP4 as no compatible archs were found.")
|
||||
# clear FP4_ARCHS
|
||||
set(FP4_ARCHS)
|
||||
endif()
|
||||
|
||||
# CUTLASS MLA Archs and flags
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||
cuda_archs_loose_intersection(MLA_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
|
||||
@@ -693,55 +512,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
set(MLA_ARCHS)
|
||||
endif()
|
||||
|
||||
# CUTLASS MoE kernels
|
||||
|
||||
# The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and ONLY works
|
||||
# on Hopper). get_cutlass_(batched_)moe_mm_data should only be compiled
|
||||
# if it's possible to compile MoE kernels that use its output.
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
|
||||
set(SRCS "csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm90.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM90=1")
|
||||
message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
|
||||
message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is "
|
||||
"not >= 12.3, we recommend upgrading to CUDA 12.3 or later "
|
||||
"if you intend on running FP8 quantized MoE models on Hopper.")
|
||||
else()
|
||||
message(STATUS "Not building grouped_mm_c3x as no compatible archs found "
|
||||
"in CUDA target architectures.")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
|
||||
else()
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
|
||||
endif()
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
set(SRCS "csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm100.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
|
||||
message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is "
|
||||
"not >= 12.8, we recommend upgrading to CUDA 12.8 or later "
|
||||
"if you intend on running FP8 quantized MoE models on Blackwell.")
|
||||
else()
|
||||
message(STATUS "Not building grouped_mm_c3x as no compatible archs found "
|
||||
"in CUDA target architectures.")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# Expert-specialization MXFP8 blockscaled grouped kernels (SM100+).
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||
cuda_archs_loose_intersection(ES_MXFP8_GROUPED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
|
||||
@@ -787,36 +557,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
"in CUDA target architectures.")
|
||||
endif()
|
||||
|
||||
# moe_data.cu is used by all CUTLASS MoE kernels.
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||
cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
|
||||
else()
|
||||
cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}")
|
||||
endif()
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)
|
||||
set(SRCS "csrc/quantization/w8a8/cutlass/moe/moe_data.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${CUTLASS_MOE_DATA_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
message(STATUS "Building moe_data for archs: ${CUTLASS_MOE_DATA_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)
|
||||
message(STATUS "Not building moe_data as CUDA Compiler version is "
|
||||
"not >= 12.3, we recommend upgrading to CUDA 12.3 or later "
|
||||
"if you intend on running FP8 quantized MoE models on Hopper or Blackwell.")
|
||||
else()
|
||||
message(STATUS "Not building moe_data as no compatible archs found "
|
||||
"in CUDA target architectures.")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
|
||||
else()
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
|
||||
endif()
|
||||
|
||||
#
|
||||
# Machete kernels
|
||||
|
||||
@@ -887,34 +627,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# Only build W4A8 kernels if we are building for something compatible with sm90a
|
||||
cuda_archs_loose_intersection(W4A8_ARCHS "9.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND W4A8_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu"
|
||||
"csrc/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu"
|
||||
"csrc/quantization/cutlass_w4a8/w4a8_utils.cu"
|
||||
)
|
||||
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${W4A8_ARCHS}")
|
||||
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
|
||||
message(STATUS "Building W4A8 kernels for archs: ${W4A8_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0
|
||||
AND W4A8_ARCHS)
|
||||
message(STATUS "Not building W4A8 kernels as CUDA Compiler version is "
|
||||
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
|
||||
"later if you intend on running w4a16 quantized models on "
|
||||
"Hopper.")
|
||||
else()
|
||||
message(STATUS "Not building W4A8 kernels as no compatible archs "
|
||||
"found in CUDA target architectures")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# Hadacore kernels
|
||||
cuda_archs_loose_intersection(HADACORE_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}")
|
||||
@@ -964,7 +676,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
  # _C_stable_libtorch extension (ops registered via STABLE_TORCH_LIBRARY)
  #
  set(VLLM_STABLE_EXT_SRC
    "csrc/libtorch_stable/torch_bindings.cpp")
    "csrc/libtorch_stable/torch_bindings.cpp"
    "csrc/cutlass_extensions/common.cpp"
    "csrc/cuda_utils_kernels.cu"
    "csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu"
    "csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu"
    "csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu")

  if(VLLM_GPU_LANG STREQUAL "CUDA")
    list(APPEND VLLM_STABLE_EXT_SRC
@@ -979,6 +696,299 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
CUDA_ARCHS "${CUDA_ARCHS}")
|
||||
endif()
|
||||
|
||||
#
|
||||
# CUTLASS scaled_mm kernels (moved from _C to _C_stable_libtorch)
|
||||
#
|
||||
set(SCALED_MM_3X_ARCHS)
|
||||
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
|
||||
# CUDA 12.0 or later
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c3x_sm90.cu"
|
||||
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8.cu"
|
||||
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_int8.cu"
|
||||
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_azp_sm90_int8.cu"
|
||||
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
||||
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM90=1")
|
||||
# Let scaled_mm_c2x know it doesn't need to build these arches
|
||||
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
|
||||
message(STATUS "Building scaled_mm_c3x_sm90 for archs: ${SCALED_MM_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS)
|
||||
message(STATUS "Not building scaled_mm_c3x_sm90 as CUDA Compiler version is "
|
||||
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
|
||||
"later if you intend on running FP8 quantized models on "
|
||||
"Hopper.")
|
||||
else()
|
||||
message(STATUS "Not building scaled_mm_c3x_sm90 as no compatible archs found "
|
||||
"in CUDA target architectures")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
||||
# The cutlass_scaled_mm kernels for Blackwell SM12x (c3x, i.e. CUTLASS 3.x) require
|
||||
# CUDA 12.8 or later
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0f" "${CUDA_ARCHS}")
|
||||
else()
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0a;12.1a" "${CUDA_ARCHS}")
|
||||
endif()
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c3x_sm120.cu"
|
||||
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8.cu"
|
||||
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8.cu"
|
||||
)
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
||||
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM120=1")
|
||||
# Let scaled_mm_c2x know it doesn't need to build these arches
|
||||
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
|
||||
message(STATUS "Building scaled_mm_c3x_sm120 for archs: ${SCALED_MM_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
message(STATUS "Not building scaled_mm_c3x_sm120 as CUDA Compiler version is "
|
||||
"not >= 12.8, we recommend upgrading to CUDA 12.8 or "
|
||||
"later if you intend on running FP8 quantized models on "
|
||||
"Blackwell.")
|
||||
else()
|
||||
message(STATUS "Not building scaled_mm_c3x_120 as no compatible archs found "
|
||||
"in CUDA target architectures")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
||||
# The cutlass_scaled_mm kernels for Blackwell SM100 (c3x, i.e. CUTLASS 3.x)
|
||||
# require CUDA 12.8 or later
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
|
||||
else()
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
|
||||
endif()
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c3x_sm100.cu"
|
||||
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu"
|
||||
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8.cu"
|
||||
)
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
||||
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM100=1")
|
||||
# Let scaled_mm_c2x know it doesn't need to build these arches
|
||||
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
|
||||
message(STATUS "Building scaled_mm_c3x_sm100 for archs: ${SCALED_MM_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
message(STATUS "Not building scaled_mm_c3x_sm100 as CUDA Compiler version is "
|
||||
"not >= 12.8, we recommend upgrading to CUDA 12.8 or "
|
||||
"later if you intend on running FP8 quantized models on "
|
||||
"Blackwell.")
|
||||
else()
|
||||
message(STATUS "Not building scaled_mm_c3x_100 as no compatible archs found "
|
||||
"in CUDA target architectures")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
#
|
||||
# For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x)
|
||||
# kernels for the remaining archs that are not already built for 3x.
|
||||
# (Build 8.9 for FP8)
|
||||
cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS
|
||||
"7.5;8.0;8.7;8.9+PTX" "${CUDA_ARCHS}")
|
||||
# subtract out the archs that are already built for 3x
|
||||
list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS})
|
||||
if (SCALED_MM_2X_ARCHS)
|
||||
set(SRCS "csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c2x.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_2X_ARCHS}")
|
||||
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_C2X=1")
|
||||
message(STATUS "Building scaled_mm_c2x for archs: ${SCALED_MM_2X_ARCHS}")
|
||||
else()
|
||||
if (SCALED_MM_3X_ARCHS)
|
||||
message(STATUS "Not building scaled_mm_c2x as all archs are already built"
|
||||
" for and covered by scaled_mm_c3x")
|
||||
else()
|
||||
message(STATUS "Not building scaled_mm_c2x as no compatible archs found "
|
||||
"in CUDA target architectures")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
#
|
||||
# CUTLASS MoE kernels (moved from _C to _C_stable_libtorch)
|
||||
#
|
||||
|
||||
# The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and ONLY works
|
||||
# on Hopper). get_cutlass_(batched_)moe_mm_data should only be compiled
|
||||
# if it's possible to compile MoE kernels that use its output.
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
|
||||
set(SRCS "csrc/libtorch_stable/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm90.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
||||
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM90=1")
|
||||
message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
|
||||
message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is "
|
||||
"not >= 12.3, we recommend upgrading to CUDA 12.3 or later "
|
||||
"if you intend on running FP8 quantized MoE models on Hopper.")
|
||||
else()
|
||||
message(STATUS "Not building grouped_mm_c3x as no compatible archs found "
|
||||
"in CUDA target architectures.")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
|
||||
else()
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
|
||||
endif()
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
set(SRCS "csrc/libtorch_stable/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm100.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
||||
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
|
||||
message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is "
|
||||
"not >= 12.8, we recommend upgrading to CUDA 12.8 or later "
|
||||
"if you intend on running FP8 quantized MoE models on Blackwell.")
|
||||
else()
|
||||
message(STATUS "Not building grouped_mm_c3x as no compatible archs found "
|
||||
"in CUDA target architectures.")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# moe_data.cu is used by all CUTLASS MoE kernels.
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||
cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
|
||||
else()
|
||||
cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}")
|
||||
endif()
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)
|
||||
set(SRCS "csrc/libtorch_stable/quantization/w8a8/cutlass/moe/moe_data.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${CUTLASS_MOE_DATA_ARCHS}")
|
||||
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
|
||||
message(STATUS "Building moe_data for archs: ${CUTLASS_MOE_DATA_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)
|
||||
message(STATUS "Not building moe_data as CUDA Compiler version is "
|
||||
"not >= 12.3, we recommend upgrading to CUDA 12.3 or later "
|
||||
"if you intend on running FP8 quantized MoE models on Hopper or Blackwell.")
|
||||
else()
|
||||
message(STATUS "Not building moe_data as no compatible archs found "
|
||||
"in CUDA target architectures.")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
#
|
||||
# FP4/NVFP4 kernels (moved from _C to _C_stable_libtorch)
|
||||
#
|
||||
|
||||
# The nvfp4_scaled_mm_sm120 kernels for Blackwell SM12x require
|
||||
# CUDA 12.8 or later
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||
cuda_archs_loose_intersection(FP4_ARCHS "12.0f" "${CUDA_ARCHS}")
|
||||
else()
|
||||
cuda_archs_loose_intersection(FP4_ARCHS "12.0a;12.1a" "${CUDA_ARCHS}")
|
||||
endif()
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/libtorch_stable/quantization/fp4/nvfp4_quant_kernels.cu"
|
||||
"csrc/libtorch_stable/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
|
||||
"csrc/libtorch_stable/quantization/fp4/nvfp4_experts_quant.cu"
|
||||
"csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu"
|
||||
"csrc/libtorch_stable/quantization/fp4/nvfp4_blockwise_moe_kernel.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${FP4_ARCHS}")
|
||||
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM120=1")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM120=1")
|
||||
message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
|
||||
else()
|
||||
message(STATUS "Not building NVFP4 as no compatible archs were found.")
|
||||
# clear FP4_ARCHS
|
||||
set(FP4_ARCHS)
|
||||
endif()
|
||||
|
||||
# FP4 Archs and flags
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||
cuda_archs_loose_intersection(FP4_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
|
||||
else()
|
||||
cuda_archs_loose_intersection(FP4_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
|
||||
endif()
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/libtorch_stable/quantization/fp4/nvfp4_quant_kernels.cu"
|
||||
"csrc/libtorch_stable/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
|
||||
"csrc/libtorch_stable/quantization/fp4/nvfp4_experts_quant.cu"
|
||||
"csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_kernels.cu"
|
||||
"csrc/libtorch_stable/quantization/fp4/nvfp4_blockwise_moe_kernel.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${FP4_ARCHS}")
|
||||
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM100=1")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
|
||||
message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
|
||||
else()
|
||||
message(STATUS "Not building NVFP4 as no compatible archs were found.")
|
||||
# clear FP4_ARCHS
|
||||
set(FP4_ARCHS)
|
||||
endif()
|
||||
|
||||
#
|
||||
# W4A8 kernels (moved from _C to _C_stable_libtorch)
|
||||
#
|
||||
|
||||
# Only build W4A8 kernels if we are building for something compatible with sm90a
|
||||
cuda_archs_loose_intersection(W4A8_ARCHS "9.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND W4A8_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_mm_entry.cu"
|
||||
"csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu"
|
||||
"csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_utils.cu"
|
||||
)
|
||||
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${W4A8_ARCHS}")
|
||||
|
||||
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
|
||||
|
||||
message(STATUS "Building W4A8 kernels for archs: ${W4A8_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0
|
||||
AND W4A8_ARCHS)
|
||||
message(STATUS "Not building W4A8 kernels as CUDA Compiler version is "
|
||||
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
|
||||
"later if you intend on running w4a16 quantized models on "
|
||||
"Hopper.")
|
||||
else()
|
||||
message(STATUS "Not building W4A8 kernels as no compatible archs "
|
||||
"found in CUDA target architectures")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
message(STATUS "Enabling C_stable extension.")
|
||||
define_extension_target(
|
||||
_C_stable_libtorch
|
||||
@@ -987,6 +997,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
    SOURCES ${VLLM_STABLE_EXT_SRC}
    COMPILE_FLAGS ${VLLM_GPU_FLAGS}
    ARCHITECTURES ${VLLM_GPU_ARCHES}
    INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
    USE_SABI 3
    WITH_SOABI)

@@ -1000,6 +1011,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
    # Needed to use cuda APIs from C-shim
    target_compile_definitions(_C_stable_libtorch PRIVATE
      USE_CUDA)

    # Needed by CUTLASS kernels
    target_compile_definitions(_C_stable_libtorch PRIVATE
      CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
  endif()

  #
@@ -1015,7 +1030,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
    list(APPEND VLLM_MOE_EXT_SRC
      "csrc/moe/moe_wna16.cu"
      "csrc/moe/grouped_topk_kernels.cu"
      "csrc/moe/gpt_oss_router_gemm.cu"
      "csrc/moe/router_gemm.cu")
  endif()
benchmarks/fused_kernels/merge_attn_states_benchmarks.py (new file, 264 lines)
@@ -0,0 +1,264 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Benchmark: Fused FP8 output quantization in merge_attn_states

Compares fused vs unfused approaches for producing FP8-quantized merged
attention output:
1. Fused CUDA -- single CUDA kernel (merge + FP8 quant)
2. Fused Triton -- single Triton kernel (merge + FP8 quant)
3. Unfused CUDA -- CUDA merge + torch.compiled FP8 quant
4. Unfused Triton -- Triton merge + torch.compiled FP8 quant

Usage:
    python benchmarks/fused_kernels/merge_attn_states_benchmarks.py
    python benchmarks/fused_kernels/merge_attn_states_benchmarks.py --tp 1 4 8
    python benchmarks/fused_kernels/merge_attn_states_benchmarks.py --dtype bfloat16
"""

import argparse
import itertools

import torch

from vllm._custom_ops import merge_attn_states as merge_attn_states_cuda
from vllm.benchmarks.lib.utils import default_vllm_config
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform
from vllm.triton_utils import triton
from vllm.v1.attention.ops.triton_merge_attn_states import (
    merge_attn_states as merge_attn_states_triton,
)

# ---------------------------------------------------------------------------
# Configuration defaults
# ---------------------------------------------------------------------------

NUM_TOKENS_LIST = [1, 16, 64, 256, 1024, 4096]

# (label, num_heads, head_size) — num_heads is for TP=1
HEAD_CONFIGS = [
    ("DeepSeek-V3 MLA", 128, 128),
    ("Llama-70B", 64, 128),
    ("Llama-8B", 32, 128),
]

TP_SIZES = [1, 2, 4, 8]

INPUT_DTYPES = [torch.float32, torch.float16, torch.bfloat16]

QUANTILES = [0.5, 0.2, 0.8]


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def short_dtype(dtype: torch.dtype) -> str:
    return str(dtype).removeprefix("torch.")


def make_inputs(
    num_tokens: int,
    num_heads: int,
    head_size: int,
    dtype: torch.dtype,
):
    """Create random prefix/suffix outputs and LSEs."""
    prefix_output = torch.randn(
        (num_tokens, num_heads, head_size), dtype=dtype, device="cuda"
    )
    suffix_output = torch.randn(
        (num_tokens, num_heads, head_size), dtype=dtype, device="cuda"
    )
    prefix_lse = torch.randn(num_heads, num_tokens, dtype=torch.float32, device="cuda")
    suffix_lse = torch.randn(num_heads, num_tokens, dtype=torch.float32, device="cuda")
    # Sprinkle some inf values to exercise edge-case paths
    mask = torch.rand(num_heads, num_tokens, device="cuda") < 0.05
    prefix_lse[mask] = float("inf")
    mask2 = torch.rand(num_heads, num_tokens, device="cuda") < 0.05
    suffix_lse[mask2] = float("inf")
    return prefix_output, suffix_output, prefix_lse, suffix_lse


def build_configs(head_configs, num_tokens_list, input_dtypes, tp_sizes):
    """Build (num_tokens, num_heads, head_size, dtype_str) config tuples,
    applying TP division to num_heads and skipping invalid combos."""
    configs = []
    for (_, nh, hs), nt, dtype, tp in itertools.product(
        head_configs, num_tokens_list, input_dtypes, tp_sizes
    ):
        nh_tp = nh // tp
        if nh_tp >= 1:
            configs.append((nt, nh_tp, hs, short_dtype(dtype)))
|
||||
return configs
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Benchmark merge_attn_states fused FP8 quantization"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-tokens",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=None,
|
||||
help=f"Override token counts (default: {NUM_TOKENS_LIST})",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tp",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=None,
|
||||
help=f"TP sizes to simulate (divides num_heads) (default: {TP_SIZES})",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
type=str,
|
||||
nargs="+",
|
||||
default=None,
|
||||
help="Input dtypes (e.g. bfloat16 float16 float32). "
|
||||
f"Default: {[short_dtype(d) for d in INPUT_DTYPES]}",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Parse args and build configs before decorators
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
args = parse_args()
|
||||
|
||||
num_tokens_list = args.num_tokens if args.num_tokens else NUM_TOKENS_LIST
|
||||
tp_sizes = args.tp if args.tp else TP_SIZES
|
||||
|
||||
if args.dtype:
|
||||
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||
|
||||
input_dtypes = [STR_DTYPE_TO_TORCH_DTYPE[d] for d in args.dtype]
|
||||
else:
|
||||
input_dtypes = INPUT_DTYPES
|
||||
|
||||
configs = build_configs(HEAD_CONFIGS, num_tokens_list, input_dtypes, tp_sizes)
|
||||
|
||||
torch._dynamo.config.recompile_limit = 8888
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Benchmark function
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["num_tokens", "num_heads", "head_size", "dtype_str"],
|
||||
x_vals=configs,
|
||||
line_arg="provider",
|
||||
line_vals=["fused_cuda", "fused_triton", "unfused_cuda", "unfused_triton"],
|
||||
line_names=["Fused CUDA", "Fused Triton", "Unfused CUDA", "Unfused Triton"],
|
||||
styles=[("blue", "-"), ("green", "-"), ("blue", "--"), ("green", "--")],
|
||||
ylabel="us",
|
||||
plot_name="merge_attn_states FP8 (fused vs unfused)",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
@default_vllm_config()
|
||||
def benchmark(num_tokens, num_heads, head_size, dtype_str, provider):
|
||||
input_dtype = getattr(torch, dtype_str)
|
||||
fp8_dtype = current_platform.fp8_dtype()
|
||||
prefix_out, suffix_out, prefix_lse, suffix_lse = make_inputs(
|
||||
num_tokens, num_heads, head_size, input_dtype
|
||||
)
|
||||
output_scale = torch.tensor([0.1], dtype=torch.float32, device="cuda")
|
||||
|
||||
if provider == "fused_cuda":
|
||||
output = torch.empty(
|
||||
(num_tokens, num_heads, head_size), dtype=fp8_dtype, device="cuda"
|
||||
)
|
||||
fn = lambda: merge_attn_states_cuda(
|
||||
output,
|
||||
prefix_out,
|
||||
prefix_lse,
|
||||
suffix_out,
|
||||
suffix_lse,
|
||||
output_scale=output_scale,
|
||||
)
|
||||
elif provider == "fused_triton":
|
||||
output = torch.empty(
|
||||
(num_tokens, num_heads, head_size), dtype=fp8_dtype, device="cuda"
|
||||
)
|
||||
fn = lambda: merge_attn_states_triton(
|
||||
output,
|
||||
prefix_out,
|
||||
prefix_lse,
|
||||
suffix_out,
|
||||
suffix_lse,
|
||||
output_scale=output_scale,
|
||||
)
|
||||
elif provider == "unfused_cuda":
|
||||
merge_buf = torch.empty(
|
||||
(num_tokens, num_heads, head_size), dtype=input_dtype, device="cuda"
|
||||
)
|
||||
quant_fp8 = QuantFP8(
|
||||
static=True,
|
||||
group_shape=GroupShape.PER_TENSOR,
|
||||
column_major_scales=False,
|
||||
)
|
||||
quant_input = merge_buf.view(-1, head_size)
|
||||
compiled_quant = torch.compile(
|
||||
quant_fp8.forward_native, fullgraph=True, dynamic=False
|
||||
)
|
||||
|
||||
def unfused_fn():
|
||||
merge_attn_states_cuda(
|
||||
merge_buf, prefix_out, prefix_lse, suffix_out, suffix_lse
|
||||
)
|
||||
compiled_quant(quant_input, output_scale)
|
||||
|
||||
fn = unfused_fn
|
||||
else: # unfused_triton
|
||||
merge_buf = torch.empty(
|
||||
(num_tokens, num_heads, head_size), dtype=input_dtype, device="cuda"
|
||||
)
|
||||
quant_fp8 = QuantFP8(
|
||||
static=True,
|
||||
group_shape=GroupShape.PER_TENSOR,
|
||||
column_major_scales=False,
|
||||
)
|
||||
quant_input = merge_buf.view(-1, head_size)
|
||||
compiled_quant = torch.compile(
|
||||
quant_fp8.forward_native, fullgraph=True, dynamic=False
|
||||
)
|
||||
|
||||
def unfused_fn():
|
||||
merge_attn_states_triton(
|
||||
merge_buf, prefix_out, prefix_lse, suffix_out, suffix_lse
|
||||
)
|
||||
compiled_quant(quant_input, output_scale)
|
||||
|
||||
fn = unfused_fn
|
||||
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=QUANTILES)
|
||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms # us
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def main():
|
||||
device_name = current_platform.get_device_name()
|
||||
print(f"Device: {device_name}")
|
||||
print(f"Token counts: {num_tokens_list}")
|
||||
print(f"TP sizes: {tp_sizes}")
|
||||
print(f"Input dtypes: {[short_dtype(d) for d in input_dtypes]}")
|
||||
print(f"Head configs: {[(c[0], c[1], c[2]) for c in HEAD_CONFIGS]}")
|
||||
benchmark.run(print_data=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
with torch.inference_mode():
|
||||
main()
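For reference, the merge step that all four providers above implement can be written in a few lines of eager PyTorch. This is a minimal sketch of the standard log-sum-exp combination, not the vLLM kernels themselves; the inf edge-case handling and the FP8 epilogue of the fused paths are omitted.

import torch

def merge_attn_states_reference(prefix_out, prefix_lse, suffix_out, suffix_lse):
    # prefix_out/suffix_out: [num_tokens, num_heads, head_size]
    # prefix_lse/suffix_lse: [num_heads, num_tokens]
    max_lse = torch.maximum(prefix_lse, suffix_lse)
    p_w = torch.exp(prefix_lse - max_lse)
    s_w = torch.exp(suffix_lse - max_lse)
    denom = p_w + s_w
    # reshape the per-(head, token) weights to [n, h, 1] so they broadcast over head_size
    p_w = (p_w / denom).transpose(0, 1).unsqueeze(-1)
    s_w = (s_w / denom).transpose(0, 1).unsqueeze(-1)
    merged = p_w * prefix_out.float() + s_w * suffix_out.float()
    merged_lse = max_lse + torch.log(denom)
    return merged, merged_lse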
|
||||
211
benchmarks/fused_kernels/silu_mul_block_quant_benchmark.py
Normal file
@@ -0,0 +1,211 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable, Iterable
|
||||
from dataclasses import dataclass
|
||||
from itertools import product
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.benchmark as TBenchmark
|
||||
from torch.utils.benchmark import Measurement as TMeasurement
|
||||
from tqdm import tqdm
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class bench_params_t:
|
||||
num_tokens: int
|
||||
hidden_size: int
|
||||
dtype: torch.dtype
|
||||
group_size: int # Changed from list[int] to int
|
||||
|
||||
def description(self):
|
||||
return (
|
||||
f"N {self.num_tokens} "
|
||||
f"x D {self.hidden_size} "
|
||||
f"x DT {self.dtype} "
|
||||
f"x GS {self.group_size}"
|
||||
)
|
||||
|
||||
|
||||
def get_bench_params() -> list[bench_params_t]:
|
||||
"""Test configurations covering common model sizes."""
|
||||
NUM_TOKENS = [16, 128, 512, 2048]
|
||||
HIDDEN_SIZES = [1024, 2048, 4096, 5120, 14336] # Common FFN sizes
|
||||
DTYPES = [torch.float16, torch.bfloat16]
|
||||
GROUP_SIZES = [64, 128] # Changed from [[1, 64], [1, 128]]
|
||||
|
||||
combinations = product(NUM_TOKENS, HIDDEN_SIZES, DTYPES, GROUP_SIZES)
|
||||
bench_params = list(
|
||||
map(lambda x: bench_params_t(x[0], x[1], x[2], x[3]), combinations)
|
||||
)
|
||||
return bench_params
|
||||
|
||||
|
||||
# Reference implementations
|
||||
def unfused_fp8_impl(
|
||||
x: torch.Tensor,
|
||||
quant_dtype: torch.dtype,
|
||||
group_size: int, # Changed from list[int]
|
||||
):
|
||||
"""Unfused: SiLU+Mul then per-tensor quantize."""
|
||||
hidden = x.shape[-1] // 2
|
||||
gate, up = x.split(hidden, dim=-1)
|
||||
|
||||
# SiLU(gate) * up
|
||||
silu_out = F.silu(gate) * up
|
||||
|
||||
# Per-tensor quantize (no group_size used here)
|
||||
silu_out, _ = ops.scaled_fp8_quant(silu_out)
|
||||
|
||||
|
||||
def unfused_groupwise_fp8_impl(
|
||||
x: torch.Tensor,
|
||||
quant_dtype: torch.dtype,
|
||||
group_size: int, # Changed from list[int]
|
||||
):
|
||||
"""Unfused: SiLU+Mul then group-wise quantize."""
|
||||
hidden = x.shape[-1] // 2
|
||||
gate, up = x.split(hidden, dim=-1)
|
||||
|
||||
# SiLU(gate) * up
|
||||
silu_out = F.silu(gate) * up
|
||||
|
||||
# Group quantize - use group_size directly
|
||||
silu_out, _ = per_token_group_quant_fp8(
|
||||
silu_out, group_size=group_size, use_ue8m0=False
|
||||
)
|
||||
|
||||
|
||||
def fused_impl(
|
||||
x: torch.Tensor,
|
||||
quant_dtype: torch.dtype,
|
||||
group_size: int,
|
||||
):
|
||||
"""Fused: SiLU+Mul+Block Quantization in single kernel."""
|
||||
out, _ = ops.silu_and_mul_per_block_quant(
|
||||
x,
|
||||
group_size=group_size,
|
||||
quant_dtype=quant_dtype,
|
||||
is_scale_transposed=False,
|
||||
)
|
||||
|
||||
|
||||
# Bench functions
|
||||
def bench_fn(
|
||||
x: torch.Tensor,
|
||||
quant_dtype: torch.dtype,
|
||||
group_size: int,
|
||||
label: str,
|
||||
sub_label: str,
|
||||
fn: Callable,
|
||||
description: str,
|
||||
) -> TMeasurement:
|
||||
min_run_time = 1
|
||||
|
||||
globals = {
|
||||
"x": x,
|
||||
"quant_dtype": quant_dtype,
|
||||
"group_size": group_size,
|
||||
"fn": fn,
|
||||
}
|
||||
return TBenchmark.Timer(
|
||||
stmt="fn(x, quant_dtype, group_size)",
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description=description,
|
||||
).blocked_autorange(min_run_time=min_run_time)
|
||||
|
||||
|
||||
def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasurement]:
|
||||
"""Run benchmarks for all implementations."""
|
||||
# Make inputs: [num_tokens, hidden_size * 2] for [gate || up]
|
||||
scale = 1 / params.hidden_size
|
||||
x = (
|
||||
torch.randn(
|
||||
params.num_tokens,
|
||||
params.hidden_size * 2,
|
||||
dtype=params.dtype,
|
||||
device="cuda",
|
||||
)
|
||||
* scale
|
||||
)
|
||||
|
||||
timers = []
|
||||
|
||||
# Unfused per-tensor FP8
|
||||
timers.append(
|
||||
bench_fn(
|
||||
x,
|
||||
torch.float8_e4m3fn,
|
||||
params.group_size,
|
||||
label,
|
||||
sub_label,
|
||||
unfused_fp8_impl,
|
||||
"unfused_fp8_impl",
|
||||
)
|
||||
)
|
||||
|
||||
# Unfused group-wise FP8
|
||||
timers.append(
|
||||
bench_fn(
|
||||
x,
|
||||
torch.float8_e4m3fn,
|
||||
params.group_size,
|
||||
label,
|
||||
sub_label,
|
||||
unfused_groupwise_fp8_impl,
|
||||
"unfused_groupwise_fp8_impl",
|
||||
)
|
||||
)
|
||||
|
||||
# Fused group-wise FP8
|
||||
timers.append(
|
||||
bench_fn(
|
||||
x,
|
||||
torch.float8_e4m3fn,
|
||||
params.group_size,
|
||||
label,
|
||||
sub_label,
|
||||
fused_impl,
|
||||
"fused_groupwise_fp8_impl",
|
||||
)
|
||||
)
|
||||
|
||||
return timers
|
||||
|
||||
|
||||
def print_timers(timers: Iterable[TMeasurement]):
|
||||
compare = TBenchmark.Compare(timers)
|
||||
compare.print()
|
||||
|
||||
|
||||
def main():
|
||||
torch.set_default_device("cuda")
|
||||
bench_params = get_bench_params()
|
||||
|
||||
print(f"Running {len(bench_params)} benchmark configurations...")
|
||||
print(
|
||||
f"This will take approximately {len(bench_params) * 3} seconds (1s per variant)"
|
||||
)
|
||||
print()
|
||||
|
||||
timers = []
|
||||
for bp in tqdm(bench_params):
|
||||
result_timers = bench(bp, "silu-mul-block-quant", bp.description())
|
||||
timers.extend(result_timers)
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("FINAL COMPARISON - ALL RESULTS")
|
||||
print("=" * 80)
|
||||
print_timers(timers)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
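For context, the group-wise quantization step of the unfused baseline can be sketched in eager PyTorch as below. This is an illustrative reference, not vLLM's per_token_group_quant_fp8; the 448.0 limit assumes float8_e4m3fn and the epsilon handling is simplified.

import torch

def group_quant_fp8_reference(x, group_size, fp8_max=448.0):
    # x: [num_tokens, hidden]; one scale per contiguous group of group_size elements
    n, h = x.shape
    g = x.view(n, h // group_size, group_size).float()
    scales = g.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / fp8_max
    q = (g / scales).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return q.view(n, h), scales.squeeze(-1)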
|
||||
@@ -1,134 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.transformers_utils.config import get_config
|
||||
from vllm.triton_utils import triton
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
|
||||
# Dimensions supported by the DSV3 specialized kernel
|
||||
DSV3_SUPPORTED_NUM_EXPERTS = [256, 384]
|
||||
DSV3_SUPPORTED_HIDDEN_SIZES = [7168]
|
||||
|
||||
# Dimensions supported by the gpt-oss specialized kernel
|
||||
GPT_OSS_SUPPORTED_NUM_EXPERTS = [32, 128]
|
||||
GPT_OSS_SUPPORTED_HIDDEN_SIZES = [2880]
|
||||
|
||||
|
||||
def get_batch_size_range(max_batch_size):
|
||||
return [2**x for x in range(14) if 2**x <= max_batch_size]
|
||||
|
||||
|
||||
def get_model_params(config):
|
||||
if config.architectures[0] in (
|
||||
"DeepseekV2ForCausalLM",
|
||||
"DeepseekV3ForCausalLM",
|
||||
"DeepseekV32ForCausalLM",
|
||||
):
|
||||
num_experts = config.n_routed_experts
|
||||
hidden_size = config.hidden_size
|
||||
elif config.architectures[0] in ("GptOssForCausalLM",):
|
||||
num_experts = config.num_local_experts
|
||||
hidden_size = config.hidden_size
|
||||
else:
|
||||
raise ValueError(f"Unsupported architecture: {config.architectures}")
|
||||
return num_experts, hidden_size
|
||||
|
||||
|
||||
def get_benchmark(model, max_batch_size, trust_remote_code):
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size"],
|
||||
x_vals=get_batch_size_range(max_batch_size),
|
||||
x_log=False,
|
||||
line_arg="provider",
|
||||
line_vals=[
|
||||
"torch",
|
||||
"vllm",
|
||||
],
|
||||
line_names=["PyTorch", "vLLM"],
|
||||
styles=([("blue", "-"), ("red", "-")]),
|
||||
ylabel="TFLOPs",
|
||||
plot_name=f"{model} router gemm throughput",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(batch_size, provider):
|
||||
config = get_config(model=model, trust_remote_code=trust_remote_code)
|
||||
num_experts, hidden_size = get_model_params(config)
|
||||
|
||||
mat_a = torch.randn(
|
||||
(batch_size, hidden_size), dtype=torch.bfloat16, device="cuda"
|
||||
).contiguous()
|
||||
mat_b = torch.randn(
|
||||
(num_experts, hidden_size), dtype=torch.bfloat16, device="cuda"
|
||||
).contiguous()
|
||||
bias = torch.randn(
|
||||
num_experts, dtype=torch.bfloat16, device="cuda"
|
||||
).contiguous()
|
||||
|
||||
is_hopper_or_blackwell = current_platform.is_device_capability(
|
||||
90
|
||||
) or current_platform.is_device_capability_family(100)
|
||||
allow_dsv3_router_gemm = (
|
||||
is_hopper_or_blackwell
|
||||
and num_experts in DSV3_SUPPORTED_NUM_EXPERTS
|
||||
and hidden_size in DSV3_SUPPORTED_HIDDEN_SIZES
|
||||
)
|
||||
allow_gpt_oss_router_gemm = (
|
||||
is_hopper_or_blackwell
|
||||
and num_experts in GPT_OSS_SUPPORTED_NUM_EXPERTS
|
||||
and hidden_size in GPT_OSS_SUPPORTED_HIDDEN_SIZES
|
||||
)
|
||||
|
||||
has_bias = False
|
||||
if allow_gpt_oss_router_gemm:
|
||||
has_bias = True
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
if provider == "torch":
|
||||
|
||||
def runner():
|
||||
if has_bias:
|
||||
F.linear(mat_a, mat_b, bias)
|
||||
else:
|
||||
F.linear(mat_a, mat_b)
|
||||
elif provider == "vllm":
|
||||
|
||||
def runner():
|
||||
if allow_dsv3_router_gemm:
|
||||
ops.dsv3_router_gemm(mat_a, mat_b, torch.bfloat16)
|
||||
elif allow_gpt_oss_router_gemm:
|
||||
ops.gpt_oss_router_gemm(mat_a, mat_b, bias)
|
||||
else:
|
||||
raise ValueError("Unsupported router gemm")
|
||||
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
runner, quantiles=quantiles
|
||||
)
|
||||
|
||||
def tflops(t_ms):
|
||||
flops = 2 * batch_size * hidden_size * num_experts
|
||||
return flops / (t_ms * 1e-3) / 1e12
|
||||
|
||||
return tflops(ms), tflops(max_ms), tflops(min_ms)
|
||||
|
||||
return benchmark
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser()
|
||||
parser.add_argument("--model", type=str, default="openai/gpt-oss-20b")
|
||||
parser.add_argument("--max-batch-size", default=16, type=int)
|
||||
parser.add_argument("--trust-remote-code", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Get the benchmark function
|
||||
benchmark = get_benchmark(args.model, args.max_batch_size, args.trust_remote_code)
|
||||
# Run performance benchmark
|
||||
benchmark.run(print_data=True)
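As a worked example of the tflops helper above: with gpt-oss-20b-like dimensions from the supported lists (batch_size 16, hidden_size 2880, num_experts 32), the router GEMM performs 2 * 16 * 2880 * 32 ≈ 2.95 MFLOP, so a hypothetical kernel time of 0.005 ms corresponds to roughly 0.59 TFLOP/s.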
|
||||
162
benchmarks/kernels/benchmark_vit_bilinear_pos_embed.py
Normal file
@@ -0,0 +1,162 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Benchmarks the fused Triton bilinear position-embedding kernel against
|
||||
# the pure-PyTorch (native) implementation used in Qwen3-VL ViT models.
|
||||
#
|
||||
# == Usage Examples ==
|
||||
#
|
||||
# Default benchmark:
|
||||
# python3 benchmark_vit_bilinear_pos_embed.py
|
||||
#
|
||||
# Custom parameters:
|
||||
# python3 benchmark_vit_bilinear_pos_embed.py --hidden-dim 1152 \
|
||||
# --num-grid-per-side 48 --save-path ./configs/vit_pos_embed/
|
||||
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.models.qwen3_vl import (
|
||||
pos_embed_interpolate_native,
|
||||
triton_pos_embed_interpolate,
|
||||
)
|
||||
from vllm.triton_utils import HAS_TRITON, triton
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
|
||||
# (h, w) configurations to benchmark
|
||||
h_w_configs = [
|
||||
(16, 16),
|
||||
(32, 32),
|
||||
(48, 48),
|
||||
(64, 64),
|
||||
(128, 128),
|
||||
(32, 48),
|
||||
(60, 80),
|
||||
]
|
||||
|
||||
# Temporal dimensions
|
||||
t_range = [1]
|
||||
|
||||
configs = list(itertools.product(t_range, h_w_configs))
|
||||
|
||||
|
||||
def get_benchmark(
|
||||
num_grid_per_side: int,
|
||||
spatial_merge_size: int,
|
||||
hidden_dim: int,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
):
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["t", "h_w"],
|
||||
x_vals=[list(_) for _ in configs],
|
||||
line_arg="provider",
|
||||
line_vals=["native", "triton"],
|
||||
line_names=["Native (PyTorch)", "Triton"],
|
||||
styles=[("blue", "-"), ("red", "-")],
|
||||
ylabel="us",
|
||||
plot_name=(
|
||||
f"vit-bilinear-pos-embed-"
|
||||
f"grid{num_grid_per_side}-"
|
||||
f"dim{hidden_dim}-"
|
||||
f"{dtype}"
|
||||
),
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(t, h_w, provider):
|
||||
h, w = h_w
|
||||
|
||||
torch.manual_seed(42)
|
||||
embed_weight = (
|
||||
torch.randn(
|
||||
num_grid_per_side * num_grid_per_side,
|
||||
hidden_dim,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
* 0.25
|
||||
)
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
if provider == "native":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: pos_embed_interpolate_native(
|
||||
embed_weight,
|
||||
t,
|
||||
h,
|
||||
w,
|
||||
num_grid_per_side,
|
||||
spatial_merge_size,
|
||||
dtype,
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
else:
|
||||
assert HAS_TRITON, "Triton not available"
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: triton_pos_embed_interpolate(
|
||||
embed_weight,
|
||||
t,
|
||||
h,
|
||||
w,
|
||||
num_grid_per_side,
|
||||
spatial_merge_size,
|
||||
dtype,
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
|
||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||
|
||||
return benchmark
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser(
|
||||
description="Benchmark bilinear position embedding interpolation."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-grid-per-side",
|
||||
type=int,
|
||||
default=48,
|
||||
help="Position embedding grid size (default: 48 for Qwen3-VL)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--spatial-merge-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Spatial merge size (default: 2)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hidden-dim",
|
||||
type=int,
|
||||
default=1152,
|
||||
help="Embedding hidden dimension (default: 1152 for Qwen3-VL)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
choices=["cuda:0", "cuda:1"],
|
||||
default="cuda:0",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-path",
|
||||
type=str,
|
||||
default="./vit_pos_embed/",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
dtype = torch.bfloat16
|
||||
|
||||
bench = get_benchmark(
|
||||
args.num_grid_per_side,
|
||||
args.spatial_merge_size,
|
||||
args.hidden_dim,
|
||||
dtype,
|
||||
args.device,
|
||||
)
|
||||
bench.run(print_data=True, save_path=args.save_path)
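Conceptually, both providers resample a fixed num_grid_per_side x num_grid_per_side table of embeddings to the h x w patch grid with bilinear weights. A rough eager-PyTorch sketch is below; it is not the Qwen3-VL implementation, and the temporal dimension, spatial-merge reordering, and exact corner-alignment convention are simplified assumptions here.

import torch
import torch.nn.functional as F

def bilinear_pos_embed_reference(embed_weight, h, w, num_grid_per_side):
    # embed_weight: [num_grid_per_side * num_grid_per_side, hidden_dim]
    hidden_dim = embed_weight.shape[-1]
    grid = embed_weight.view(1, num_grid_per_side, num_grid_per_side, hidden_dim)
    grid = grid.permute(0, 3, 1, 2).float()  # NCHW layout expected by interpolate
    resized = F.interpolate(grid, size=(h, w), mode="bilinear", align_corners=False)
    return resized.permute(0, 2, 3, 1).reshape(h * w, hidden_dim)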
|
||||
@@ -373,6 +373,7 @@ if (ENABLE_X86_ISA)
|
||||
"csrc/cpu/sgl-kernels/gemm.cpp"
|
||||
"csrc/cpu/sgl-kernels/gemm_int8.cpp"
|
||||
"csrc/cpu/sgl-kernels/gemm_fp8.cpp"
|
||||
"csrc/cpu/sgl-kernels/gemm_int4.cpp"
|
||||
"csrc/cpu/sgl-kernels/moe.cpp"
|
||||
"csrc/cpu/sgl-kernels/moe_int8.cpp"
|
||||
"csrc/cpu/sgl-kernels/moe_fp8.cpp")
|
||||
|
||||
@@ -39,7 +39,7 @@ else()
|
||||
FetchContent_Declare(
|
||||
vllm-flash-attn
|
||||
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
|
||||
GIT_TAG 29210221863736a08f71a866459e368ad1ac4a95
|
||||
GIT_TAG c0ec424fd8a546d0cbbf4bf050bbcfe837c55afb
|
||||
GIT_PROGRESS TRUE
|
||||
# Don't share the vllm-flash-attn build between build types
|
||||
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
|
||||
|
||||
@@ -3,22 +3,33 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
|
||||
#include "attention_dtypes.h"
|
||||
#include "attention_utils.cuh"
|
||||
#include "../quantization/w8a8/fp8/common.cuh"
|
||||
#include "../dispatch_utils.h"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
|
||||
// can be used to combine partial attention results (in the split-KV case)
|
||||
template <typename scalar_t, const uint NUM_THREADS>
|
||||
template <typename scalar_t, typename output_t, const uint NUM_THREADS,
|
||||
bool USE_FP8_OUTPUT>
|
||||
__global__ void merge_attn_states_kernel(
|
||||
scalar_t* output, float* output_lse, const scalar_t* prefix_output,
|
||||
output_t* output, float* output_lse, const scalar_t* prefix_output,
|
||||
const float* prefix_lse, const scalar_t* suffix_output,
|
||||
const float* suffix_lse, const uint num_tokens, const uint num_heads,
|
||||
const uint head_size, const uint prefix_head_stride,
|
||||
const uint output_head_stride) {
|
||||
using pack_128b_t = uint4;
|
||||
const uint output_head_stride, const uint prefix_num_tokens,
|
||||
const float* output_scale) {
|
||||
// Inputs always load 128-bit packs (pack_size elements of scalar_t).
|
||||
// Outputs store pack_size elements of output_t, which is smaller for FP8.
|
||||
using input_pack_t = uint4;
|
||||
using output_pack_t =
|
||||
std::conditional_t<USE_FP8_OUTPUT,
|
||||
std::conditional_t<sizeof(scalar_t) == 4, uint, uint2>,
|
||||
uint4>;
|
||||
const uint pack_size = 16 / sizeof(scalar_t);
|
||||
const uint threads_per_head = head_size / pack_size;
|
||||
|
||||
@@ -41,8 +52,45 @@ __global__ void merge_attn_states_kernel(
|
||||
head_idx * output_head_stride;
|
||||
const scalar_t* prefix_head_ptr = prefix_output + src_head_offset;
|
||||
const scalar_t* suffix_head_ptr = suffix_output + src_head_offset;
|
||||
scalar_t* output_head_ptr = output + dst_head_offset;
|
||||
output_t* output_head_ptr = output + dst_head_offset;
|
||||
|
||||
// Pre-invert scale: multiplication is faster than division
|
||||
float fp8_scale_inv = 1.0f;
|
||||
if constexpr (USE_FP8_OUTPUT) {
|
||||
fp8_scale_inv = 1.0f / *output_scale;
|
||||
}
|
||||
|
||||
// If token_idx >= prefix_num_tokens, just copy from suffix
|
||||
if (token_idx >= prefix_num_tokens) {
|
||||
if (pack_offset < head_size) {
|
||||
input_pack_t s_out_pack = reinterpret_cast<const input_pack_t*>(
|
||||
suffix_head_ptr)[pack_offset / pack_size];
|
||||
|
||||
if constexpr (USE_FP8_OUTPUT) {
|
||||
output_t o_out_pack[pack_size];
|
||||
#pragma unroll
|
||||
for (uint i = 0; i < pack_size; ++i) {
|
||||
const float val =
|
||||
vllm::to_float(reinterpret_cast<const scalar_t*>(&s_out_pack)[i]);
|
||||
o_out_pack[i] =
|
||||
vllm::scaled_fp8_conversion<true, output_t>(val, fp8_scale_inv);
|
||||
}
|
||||
reinterpret_cast<output_pack_t*>(
|
||||
output_head_ptr)[pack_offset / pack_size] =
|
||||
*reinterpret_cast<output_pack_t*>(o_out_pack);
|
||||
} else {
|
||||
reinterpret_cast<output_pack_t*>(
|
||||
output_head_ptr)[pack_offset / pack_size] = s_out_pack;
|
||||
}
|
||||
}
|
||||
if (output_lse != nullptr && pack_idx == 0) {
|
||||
float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
|
||||
output_lse[head_idx * num_tokens + token_idx] = s_lse;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// For tokens within prefix range, merge prefix and suffix
|
||||
float p_lse = prefix_lse[head_idx * num_tokens + token_idx];
|
||||
float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
|
||||
p_lse = std::isinf(p_lse) ? -std::numeric_limits<float>::infinity() : p_lse;
|
||||
@@ -53,20 +101,34 @@ __global__ void merge_attn_states_kernel(
|
||||
/* In certain edge cases, MLA can produce p_lse = s_lse = -inf;
|
||||
continuing the pipeline then yields NaN. Root cause: with chunked prefill
|
||||
a batch may be split into two chunks; if a request in that batch has no
|
||||
prefix hit, every LSE entry for that request’s position is -inf, and at
|
||||
prefix hit, every LSE entry for that request's position is -inf, and at
|
||||
this moment we merge cross-attention at first. For now we simply emit
|
||||
prefix_output (expected to be all zeros) and prefix_lse (-inf) to fix
|
||||
this problem.
|
||||
*/
|
||||
if (std::isinf(max_lse)) {
|
||||
if (pack_offset < head_size) {
|
||||
// Pack 128b load
|
||||
pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(
|
||||
input_pack_t p_out_pack = reinterpret_cast<const input_pack_t*>(
|
||||
prefix_head_ptr)[pack_offset / pack_size];
|
||||
|
||||
// Pack 128b storage
|
||||
reinterpret_cast<pack_128b_t*>(output_head_ptr)[pack_offset / pack_size] =
|
||||
p_out_pack;
|
||||
if constexpr (USE_FP8_OUTPUT) {
|
||||
// Convert prefix values to FP8 (since -inf means no data,
|
||||
// prefix_output is expected to be zeros)
|
||||
output_t o_out_pack[pack_size];
|
||||
#pragma unroll
|
||||
for (uint i = 0; i < pack_size; ++i) {
|
||||
const float val =
|
||||
vllm::to_float(reinterpret_cast<const scalar_t*>(&p_out_pack)[i]);
|
||||
o_out_pack[i] =
|
||||
vllm::scaled_fp8_conversion<true, output_t>(val, fp8_scale_inv);
|
||||
}
|
||||
reinterpret_cast<output_pack_t*>(
|
||||
output_head_ptr)[pack_offset / pack_size] =
|
||||
*reinterpret_cast<output_pack_t*>(o_out_pack);
|
||||
} else {
|
||||
reinterpret_cast<output_pack_t*>(
|
||||
output_head_ptr)[pack_offset / pack_size] = p_out_pack;
|
||||
}
|
||||
}
|
||||
// We only need to write to output_lse once per head.
|
||||
if (output_lse != nullptr && pack_idx == 0) {
|
||||
@@ -84,30 +146,43 @@ __global__ void merge_attn_states_kernel(
|
||||
const float s_scale = s_se / out_se;
|
||||
|
||||
if (pack_offset < head_size) {
|
||||
// Pack 128b load
|
||||
pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(
|
||||
input_pack_t p_out_pack = reinterpret_cast<const input_pack_t*>(
|
||||
prefix_head_ptr)[pack_offset / pack_size];
|
||||
pack_128b_t s_out_pack = reinterpret_cast<const pack_128b_t*>(
|
||||
input_pack_t s_out_pack = reinterpret_cast<const input_pack_t*>(
|
||||
suffix_head_ptr)[pack_offset / pack_size];
|
||||
pack_128b_t o_out_pack;
|
||||
|
||||
// Compute merged values in float32
|
||||
float o_out_f[pack_size];
|
||||
#pragma unroll
|
||||
for (uint i = 0; i < pack_size; ++i) {
|
||||
// Always use float for FMA to keep high precision.
|
||||
// half(uint16_t), bfloat16, float -> float.
|
||||
const float p_out_f =
|
||||
vllm::to_float(reinterpret_cast<const scalar_t*>(&p_out_pack)[i]);
|
||||
const float s_out_f =
|
||||
vllm::to_float(reinterpret_cast<const scalar_t*>(&s_out_pack)[i]);
|
||||
// fma: a * b + c = p_out_f * p_scale + (s_out_f * s_scale)
|
||||
const float o_out_f = p_out_f * p_scale + (s_out_f * s_scale);
|
||||
// float -> half(uint16_t), bfloat16, float.
|
||||
vllm::from_float(reinterpret_cast<scalar_t*>(&o_out_pack)[i], o_out_f);
|
||||
o_out_f[i] = p_out_f * p_scale + (s_out_f * s_scale);
|
||||
}
|
||||
|
||||
// Pack 128b storage
|
||||
reinterpret_cast<pack_128b_t*>(output_head_ptr)[pack_offset / pack_size] =
|
||||
o_out_pack;
|
||||
// Convert and store
|
||||
if constexpr (USE_FP8_OUTPUT) {
|
||||
output_t o_out_pack[pack_size];
|
||||
#pragma unroll
|
||||
for (uint i = 0; i < pack_size; ++i) {
|
||||
o_out_pack[i] = vllm::scaled_fp8_conversion<true, output_t>(
|
||||
o_out_f[i], fp8_scale_inv);
|
||||
}
|
||||
reinterpret_cast<output_pack_t*>(
|
||||
output_head_ptr)[pack_offset / pack_size] =
|
||||
*reinterpret_cast<output_pack_t*>(o_out_pack);
|
||||
} else {
|
||||
output_pack_t o_out_pack;
|
||||
#pragma unroll
|
||||
for (uint i = 0; i < pack_size; ++i) {
|
||||
vllm::from_float(reinterpret_cast<scalar_t*>(&o_out_pack)[i],
|
||||
o_out_f[i]);
|
||||
}
|
||||
reinterpret_cast<output_pack_t*>(
|
||||
output_head_ptr)[pack_offset / pack_size] = o_out_pack;
|
||||
}
|
||||
}
|
||||
// We only need to write to output_lse once per head.
|
||||
if (output_lse != nullptr && pack_idx == 0) {
|
||||
@@ -134,50 +209,73 @@ __global__ void merge_attn_states_kernel(
|
||||
} \
|
||||
}
|
||||
|
||||
#define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS) \
|
||||
#define LAUNCH_MERGE_ATTN_STATES(scalar_t, output_t, NUM_THREADS, \
|
||||
USE_FP8_OUTPUT) \
|
||||
{ \
|
||||
vllm::merge_attn_states_kernel<scalar_t, NUM_THREADS> \
|
||||
vllm::merge_attn_states_kernel<scalar_t, output_t, NUM_THREADS, \
|
||||
USE_FP8_OUTPUT> \
|
||||
<<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<scalar_t*>(output.data_ptr()), output_lse_ptr, \
|
||||
reinterpret_cast<output_t*>(output.data_ptr()), output_lse_ptr, \
|
||||
reinterpret_cast<scalar_t*>(prefix_output.data_ptr()), \
|
||||
reinterpret_cast<float*>(prefix_lse.data_ptr()), \
|
||||
reinterpret_cast<scalar_t*>(suffix_output.data_ptr()), \
|
||||
reinterpret_cast<float*>(suffix_lse.data_ptr()), num_tokens, \
|
||||
num_heads, head_size, prefix_head_stride, output_head_stride); \
|
||||
num_heads, head_size, prefix_head_stride, output_head_stride, \
|
||||
prefix_num_tokens, output_scale_ptr); \
|
||||
}
|
||||
|
||||
/*@brief Merges the attention states from prefix and suffix
|
||||
* into the output tensor. NUM_TOKENS: n, NUM_HEADS: h, HEAD_SIZE: d
|
||||
*
|
||||
* @param output [n,h,d] The output tensor to store the merged attention states.
|
||||
* @param output_lse [h,d] Optional tensor to store the log-sum-exp values.
|
||||
* @param output_lse [h,n] Optional tensor to store the log-sum-exp values.
|
||||
* @param prefix_output [n,h,d] The prefix attention states.
|
||||
* @param prefix_lse [h,n] The log-sum-exp values for the prefix attention
|
||||
* states.
|
||||
* @param suffix_output [n,h,d] The suffix attention states.
|
||||
* @param suffix_lse [h,n] The log-sum-exp values for the suffix attention
|
||||
* states.
|
||||
* @param prefill_tokens_with_context Number of prefill tokens with context
|
||||
* For the first p tokens (0 <= token_idx < prefill_tokens_with_context), output
|
||||
* is computed by merging prefix_output and suffix_output. For remaining tokens
|
||||
* (prefill_tokens_with_context <= token_idx < n), output is copied directly
|
||||
* from suffix_output.
|
||||
* @param output_scale Optional scalar tensor for FP8 static quantization.
|
||||
* When provided, output must be FP8 dtype.
|
||||
*/
|
||||
template <typename scalar_t>
|
||||
void merge_attn_states_launcher(torch::Tensor& output,
|
||||
std::optional<torch::Tensor> output_lse,
|
||||
const torch::Tensor& prefix_output,
|
||||
const torch::Tensor& prefix_lse,
|
||||
const torch::Tensor& suffix_output,
|
||||
const torch::Tensor& suffix_lse) {
|
||||
void merge_attn_states_launcher(
|
||||
torch::Tensor& output, std::optional<torch::Tensor> output_lse,
|
||||
const torch::Tensor& prefix_output, const torch::Tensor& prefix_lse,
|
||||
const torch::Tensor& suffix_output, const torch::Tensor& suffix_lse,
|
||||
const std::optional<int64_t> prefill_tokens_with_context,
|
||||
const std::optional<torch::Tensor>& output_scale) {
|
||||
constexpr uint NUM_THREADS = 128;
|
||||
const uint num_tokens = output.size(0);
|
||||
const uint num_heads = output.size(1);
|
||||
const uint head_size = output.size(2);
|
||||
const uint prefix_head_stride = prefix_output.stride(1);
|
||||
const uint output_head_stride = output.stride(1);
|
||||
// Thread mapping is based on the input (scalar_t) pack_size
|
||||
const uint pack_size = 16 / sizeof(scalar_t);
|
||||
TORCH_CHECK(head_size % pack_size == 0,
|
||||
"headsize must be multiple of pack_size:", pack_size);
|
||||
|
||||
const uint prefix_num_tokens =
|
||||
prefill_tokens_with_context.has_value()
|
||||
? static_cast<uint>(prefill_tokens_with_context.value())
|
||||
: num_tokens;
|
||||
TORCH_CHECK(prefix_num_tokens <= num_tokens,
|
||||
"prefix_num_tokens must be <= num_tokens");
|
||||
|
||||
float* output_lse_ptr = nullptr;
|
||||
if (output_lse.has_value()) {
|
||||
output_lse_ptr = output_lse.value().data_ptr<float>();
|
||||
}
|
||||
float* output_scale_ptr = nullptr;
|
||||
if (output_scale.has_value()) {
|
||||
output_scale_ptr = output_scale.value().data_ptr<float>();
|
||||
}
|
||||
// Process one pack of elements per thread. For float, the
|
||||
// pack_size is 4; for half/bf16, the pack_size is 8.
|
||||
const uint threads_per_head = head_size / pack_size;
|
||||
@@ -189,14 +287,22 @@ void merge_attn_states_launcher(torch::Tensor& output,
|
||||
const c10::cuda::OptionalCUDAGuard device_guard(prefix_output.device());
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS);
|
||||
if (output_scale.has_value()) {
|
||||
// FP8 output path - dispatch on output FP8 type
|
||||
VLLM_DISPATCH_FP8_TYPES(output.scalar_type(), "merge_attn_states_fp8", [&] {
|
||||
LAUNCH_MERGE_ATTN_STATES(scalar_t, fp8_t, NUM_THREADS, true);
|
||||
});
|
||||
} else {
|
||||
// Original BF16/FP16/FP32 output path
|
||||
LAUNCH_MERGE_ATTN_STATES(scalar_t, scalar_t, NUM_THREADS, false);
|
||||
}
|
||||
}
|
||||
|
||||
#define CALL_MERGE_ATTN_STATES_LAUNCHER(scalar_t) \
|
||||
{ \
|
||||
merge_attn_states_launcher<scalar_t>(output, output_lse, prefix_output, \
|
||||
prefix_lse, suffix_output, \
|
||||
suffix_lse); \
|
||||
#define CALL_MERGE_ATTN_STATES_LAUNCHER(scalar_t) \
|
||||
{ \
|
||||
merge_attn_states_launcher<scalar_t>( \
|
||||
output, output_lse, prefix_output, prefix_lse, suffix_output, \
|
||||
suffix_lse, prefill_tokens_with_context, output_scale); \
|
||||
}
|
||||
|
||||
void merge_attn_states(torch::Tensor& output,
|
||||
@@ -204,6 +310,21 @@ void merge_attn_states(torch::Tensor& output,
|
||||
const torch::Tensor& prefix_output,
|
||||
const torch::Tensor& prefix_lse,
|
||||
const torch::Tensor& suffix_output,
|
||||
const torch::Tensor& suffix_lse) {
|
||||
DISPATCH_BY_SCALAR_DTYPE(output.dtype(), CALL_MERGE_ATTN_STATES_LAUNCHER);
|
||||
const torch::Tensor& suffix_lse,
|
||||
std::optional<int64_t> prefill_tokens_with_context,
|
||||
const std::optional<torch::Tensor>& output_scale) {
|
||||
if (output_scale.has_value()) {
|
||||
TORCH_CHECK(output.scalar_type() == at::ScalarType::Float8_e4m3fn ||
|
||||
output.scalar_type() == at::ScalarType::Float8_e4m3fnuz,
|
||||
"output must be FP8 when output_scale is provided, got: ",
|
||||
output.scalar_type());
|
||||
} else {
|
||||
TORCH_CHECK(output.scalar_type() == prefix_output.scalar_type(),
|
||||
"output dtype (", output.scalar_type(),
|
||||
") must match prefix_output dtype (",
|
||||
prefix_output.scalar_type(), ") when output_scale is not set");
|
||||
}
|
||||
// Always dispatch on prefix_output (input) dtype
|
||||
DISPATCH_BY_SCALAR_DTYPE(prefix_output.dtype(),
|
||||
CALL_MERGE_ATTN_STATES_LAUNCHER);
|
||||
}
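The FP8 path added above folds the static quantization into the merge: the merged value is divided by output_scale (the kernel pre-inverts the scale and multiplies) and converted to the FP8 output dtype. Roughly, in eager PyTorch terms (a sketch, not the kernel; the 448.0 clamp assumes float8_e4m3fn):

import torch

def fp8_epilogue_reference(merged, output_scale, fp8_max=448.0):
    # merged: float merged attention output; output_scale: scalar float32 tensor
    scaled = merged / output_scale
    return scaled.clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)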
|
||||
|
||||
@@ -10,6 +10,10 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
|
||||
int64_t block_size_in_bytes,
|
||||
const torch::Tensor& block_mapping);
|
||||
|
||||
void swap_blocks_batch(const torch::Tensor& src_ptrs,
|
||||
const torch::Tensor& dst_ptrs,
|
||||
const torch::Tensor& sizes);
|
||||
|
||||
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
|
||||
torch::Tensor& key_cache, torch::Tensor& value_cache,
|
||||
torch::Tensor& slot_mapping,
|
||||
|
||||
@@ -24,6 +24,8 @@
|
||||
#ifdef USE_ROCM
|
||||
#include <hip/hip_bf16.h>
|
||||
typedef __hip_bfloat16 __nv_bfloat16;
|
||||
#else
|
||||
#include <cuda.h>
|
||||
#endif
|
||||
|
||||
#if defined(__gfx942__)
|
||||
@@ -73,6 +75,59 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
|
||||
}
|
||||
}
|
||||
|
||||
void swap_blocks_batch(const torch::Tensor& src_ptrs,
|
||||
const torch::Tensor& dst_ptrs,
|
||||
const torch::Tensor& sizes) {
|
||||
TORCH_CHECK(src_ptrs.device().is_cpu(), "src_ptrs must be on CPU");
|
||||
TORCH_CHECK(dst_ptrs.device().is_cpu(), "dst_ptrs must be on CPU");
|
||||
TORCH_CHECK(sizes.device().is_cpu(), "sizes must be on CPU");
|
||||
TORCH_CHECK(src_ptrs.dtype() == torch::kInt64, "src_ptrs must be int64");
|
||||
TORCH_CHECK(dst_ptrs.dtype() == torch::kInt64, "dst_ptrs must be int64");
|
||||
TORCH_CHECK(sizes.dtype() == torch::kInt64, "sizes must be int64");
|
||||
|
||||
const int64_t n = src_ptrs.size(0);
|
||||
TORCH_CHECK(dst_ptrs.size(0) == n, "dst_ptrs length must match src_ptrs");
|
||||
TORCH_CHECK(sizes.size(0) == n, "sizes length must match src_ptrs");
|
||||
|
||||
if (n == 0) return;
|
||||
|
||||
const int64_t* src_data = src_ptrs.data_ptr<int64_t>();
|
||||
const int64_t* dst_data = dst_ptrs.data_ptr<int64_t>();
|
||||
const int64_t* size_data = sizes.data_ptr<int64_t>();
|
||||
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
// Use cuMemcpyBatchAsync (CUDA 12.8+) to submit all copies in a single
|
||||
// driver call, amortizing per-copy submission overhead.
|
||||
// int64_t and CUdeviceptr/size_t are both 8 bytes on 64-bit platforms,
|
||||
// so we reinterpret_cast the tensor data directly to avoid copies.
|
||||
static_assert(sizeof(CUdeviceptr) == sizeof(int64_t));
|
||||
static_assert(sizeof(size_t) == sizeof(int64_t));
|
||||
#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12080
|
||||
CUmemcpyAttributes attr = {};
|
||||
attr.srcAccessOrder = CU_MEMCPY_SRC_ACCESS_ORDER_STREAM;
|
||||
size_t attrs_idx = 0;
|
||||
size_t fail_idx = 0;
|
||||
CUresult result = cuMemcpyBatchAsync(
|
||||
reinterpret_cast<CUdeviceptr*>(const_cast<int64_t*>(dst_data)),
|
||||
reinterpret_cast<CUdeviceptr*>(const_cast<int64_t*>(src_data)),
|
||||
reinterpret_cast<size_t*>(const_cast<int64_t*>(size_data)),
|
||||
static_cast<size_t>(n), &attr, &attrs_idx, 1, &fail_idx,
|
||||
static_cast<CUstream>(stream));
|
||||
TORCH_CHECK(result == CUDA_SUCCESS, "cuMemcpyBatchAsync failed at index ",
|
||||
fail_idx, " with error ", result);
|
||||
#else
|
||||
// Fallback for CUDA < 12.8 and ROCm: individual async copies.
|
||||
// cudaMemcpyDefault lets the driver infer direction from pointer types.
|
||||
for (int64_t i = 0; i < n; i++) {
|
||||
cudaMemcpyAsync(reinterpret_cast<void*>(dst_data[i]),
|
||||
reinterpret_cast<void*>(src_data[i]),
|
||||
static_cast<size_t>(size_data[i]), cudaMemcpyDefault,
|
||||
stream);
|
||||
}
|
||||
#endif
|
||||
}
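On the caller side, the op expects three CPU int64 tensors holding raw device pointers and byte counts. A hedged sketch of how those could be assembled is below; the Python binding name for swap_blocks_batch is not shown in this diff, so only the argument preparation is illustrated.

import torch

def build_swap_blocks_batch_args(src_tensors, dst_tensors):
    # All three outputs are CPU int64 tensors, as the C++ checks require.
    src_ptrs = torch.tensor([t.data_ptr() for t in src_tensors], dtype=torch.int64)
    dst_ptrs = torch.tensor([t.data_ptr() for t in dst_tensors], dtype=torch.int64)
    sizes = torch.tensor(
        [t.numel() * t.element_size() for t in src_tensors], dtype=torch.int64
    )
    return src_ptrs, dst_ptrs, sizes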
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// Grid: (num_layers, num_pairs)
|
||||
|
||||
@@ -30,13 +30,15 @@
|
||||
}()
|
||||
|
||||
namespace {
|
||||
enum class FusedMOEAct { SiluAndMul, SwigluOAIAndMul };
|
||||
enum class FusedMOEAct { SiluAndMul, SwigluOAIAndMul, GeluAndMul };
|
||||
|
||||
FusedMOEAct get_act_type(const std::string& act) {
|
||||
if (act == "silu") {
|
||||
return FusedMOEAct::SiluAndMul;
|
||||
} else if (act == "swigluoai") {
|
||||
return FusedMOEAct::SwigluOAIAndMul;
|
||||
} else if (act == "gelu") {
|
||||
return FusedMOEAct::GeluAndMul;
|
||||
} else {
|
||||
TORCH_CHECK(false, "Invalid act type: " + act);
|
||||
}
|
||||
@@ -104,6 +106,43 @@ void silu_and_mul(float* __restrict__ input, scalar_t* __restrict__ output,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void gelu_and_mul(float* __restrict__ input, scalar_t* __restrict__ output,
|
||||
const int32_t m_size, const int32_t n_size,
|
||||
const int32_t input_stride, const int32_t output_stride) {
|
||||
using scalar_vec_t = typename cpu_utils::VecTypeTrait<scalar_t>::vec_t;
|
||||
const int32_t dim = n_size / 2;
|
||||
float* __restrict__ gate = input;
|
||||
float* __restrict__ up = input + dim;
|
||||
vec_op::FP32Vec16 one_vec(1.0);
|
||||
vec_op::FP32Vec16 w1_vec(M_SQRT1_2);
|
||||
vec_op::FP32Vec16 w2_vec(0.5);
|
||||
alignas(64) float temp[16];
|
||||
|
||||
DEFINE_FAST_EXP
|
||||
|
||||
for (int32_t m = 0; m < m_size; ++m) {
|
||||
for (int32_t n = 0; n < dim; n += 16) {
|
||||
vec_op::FP32Vec16 gate_vec(gate + n);
|
||||
vec_op::FP32Vec16 up_vec(up + n);
|
||||
auto er_input_vec = gate_vec * w1_vec;
|
||||
|
||||
er_input_vec.save(temp);
|
||||
for (int32_t i = 0; i < 16; ++i) {
|
||||
temp[i] = std::erf(temp[i]);
|
||||
}
|
||||
vec_op::FP32Vec16 er_vec(temp);
|
||||
auto gelu = gate_vec * w2_vec * (one_vec + er_vec);
|
||||
auto gated_output_fp32 = up_vec * gelu;
|
||||
scalar_vec_t gated_output = scalar_vec_t(gated_output_fp32);
|
||||
gated_output.save(output + n);
|
||||
}
|
||||
gate += input_stride;
|
||||
up += input_stride;
|
||||
output += output_stride;
|
||||
}
|
||||
}
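For reference, the erf-based gating computed by gelu_and_mul above matches PyTorch's exact (non-approximate) GELU applied to the gate half and multiplied by the up half; a short eager sketch, not the vectorized CPU path:

import torch
import torch.nn.functional as F

def gelu_and_mul_reference(x):
    # x: [..., 2 * dim] laid out as [gate || up]
    gate, up = x.chunk(2, dim=-1)
    # exact GELU: 0.5 * g * (1 + erf(g / sqrt(2)))
    return F.gelu(gate, approximate="none") * up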
|
||||
|
||||
template <typename scalar_t>
|
||||
FORCE_INLINE void apply_gated_act(const FusedMOEAct act,
|
||||
float* __restrict__ input,
|
||||
@@ -118,6 +157,9 @@ FORCE_INLINE void apply_gated_act(const FusedMOEAct act,
|
||||
case FusedMOEAct::SiluAndMul:
|
||||
silu_and_mul(input, output, m, n, input_stride, output_stride);
|
||||
return;
|
||||
case FusedMOEAct::GeluAndMul:
|
||||
gelu_and_mul(input, output, m, n, input_stride, output_stride);
|
||||
return;
|
||||
default:
|
||||
TORCH_CHECK(false, "Unsupported act type.");
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ Generate CPU attention dispatch switch cases and kernel instantiations.
|
||||
import os
|
||||
|
||||
# Head dimensions divisible by 32 (support all ISAs)
|
||||
HEAD_DIMS_32 = [32, 64, 96, 128, 160, 192, 224, 256]
|
||||
HEAD_DIMS_32 = [32, 64, 96, 128, 160, 192, 224, 256, 512]
|
||||
|
||||
# Head dimensions divisible by 16 but not 32 (VEC16 only)
|
||||
HEAD_DIMS_16 = [80, 112]
|
||||
|
||||
@@ -117,6 +117,14 @@ inline void parallel_for(int n, const func_t& f) {
|
||||
#endif
|
||||
}
|
||||
|
||||
inline int get_thread_num() {
|
||||
#if defined(_OPENMP)
|
||||
return omp_get_thread_num();
|
||||
#else
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
|
||||
// for 1d parallel, use `actual_nth`
|
||||
// for 2d parallel, use even nths, e.g. 43->42
|
||||
int inline adjust_num_threads(int m) {
|
||||
|
||||
@@ -17,8 +17,8 @@ constexpr int block_size_n() { return 2 * TILE_N; }
|
||||
template <typename T> inline bool can_use_brgemm(int M);
|
||||
template <> inline bool can_use_brgemm<at::BFloat16>(int M) { return M > 4; }
|
||||
template <> inline bool can_use_brgemm<at::Half>(int M) { return true; }
|
||||
// TODO: add u8s8 brgemm, this requires PyTorch 2.7
|
||||
template <> inline bool can_use_brgemm<int8_t>(int M) { return false; }
|
||||
template <> inline bool can_use_brgemm<int8_t>(int M) { return M > 4; }
|
||||
template <> inline bool can_use_brgemm<uint8_t>(int M) { return M > 4; }
|
||||
template <> inline bool can_use_brgemm<at::Float8_e4m3fn>(int M) { return M > 4; }
|
||||
template <> inline bool can_use_brgemm<at::quint4x2>(int M) { return M > 4; }
|
||||
|
||||
@@ -40,9 +40,17 @@ inline int64_t get_row_size(int64_t K, bool use_int8_w8a8) {
|
||||
return use_int8_w8a8 ? K + sizeof(int32_t) : K;
|
||||
}
|
||||
|
||||
// pack weight to vnni format
|
||||
inline int64_t get_4bit_block_k_size(int64_t group_size) {
|
||||
return group_size > 128 ? 128 : group_size;
|
||||
}
|
||||
|
||||
// pack weight into vnni format
|
||||
at::Tensor convert_weight_packed(at::Tensor& weight);
|
||||
|
||||
// pack weight to vnni format for int4 (adapted from sglang)
|
||||
std::tuple<at::Tensor, at::Tensor, at::Tensor>
|
||||
convert_weight_packed_scale_zp(at::Tensor qweight, at::Tensor qzeros, at::Tensor scales);
|
||||
|
||||
// moe implementations for int8 w8a8
|
||||
template <typename scalar_t>
|
||||
void fused_experts_int8_kernel_impl(
|
||||
@@ -233,6 +241,31 @@ void tinygemm_kernel(
|
||||
int64_t strideBs,
|
||||
bool brg);
|
||||
|
||||
// int4 scaled GEMM (adapted from sglang)
|
||||
at::Tensor int4_scaled_mm_cpu(
|
||||
at::Tensor& x, at::Tensor& w, at::Tensor& w_zeros, at::Tensor& w_scales, std::optional<at::Tensor> bias);
|
||||
|
||||
// int4 tinygemm kernel interface(adapted from sglang)
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(
|
||||
scalar_t* C,
|
||||
float* C_temp,
|
||||
const uint8_t* A,
|
||||
const float* scales_a,
|
||||
const int32_t* qzeros_a,
|
||||
const uint8_t* B,
|
||||
const float* scales_b,
|
||||
const int8_t* qzeros_b,
|
||||
const int32_t* compensation,
|
||||
int8_t* dqB_tmp,
|
||||
int64_t M,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldc_f,
|
||||
int64_t ldc_s,
|
||||
bool store_out,
|
||||
bool use_brgemm);
|
||||
|
||||
// TODO: debug print, remove me later
|
||||
inline void print_16x32i(const __m512i x) {
|
||||
int32_t a[16];
|
||||
|
||||
755
csrc/cpu/sgl-kernels/gemm_int4.cpp
Normal file
@@ -0,0 +1,755 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// Adapted from sgl-project/sglang
|
||||
// https://github.com/sgl-project/sglang/pull/8226
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
#include "common.h"
|
||||
#include "gemm.h"
|
||||
#include "vec.h"
|
||||
|
||||
namespace {
|
||||
|
||||
#define BLOCK_N block_size_n()
|
||||
#define BLOCK_M 128
|
||||
|
||||
template <bool sym_quant_act>
|
||||
struct ActDtype;
|
||||
template <>
|
||||
struct ActDtype<true> {
|
||||
using type = int8_t;
|
||||
};
|
||||
template <>
|
||||
struct ActDtype<false> {
|
||||
using type = uint8_t;
|
||||
};
|
||||
|
||||
struct alignas(32) m256i_wrapper {
|
||||
__m256i data;
|
||||
};
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
inline std::array<m256i_wrapper, 2> load_zps_4vnni(
|
||||
const int8_t* __restrict__ zps) {
|
||||
__m256i vzps_low = _mm256_set1_epi64x(*reinterpret_cast<const int64_t*>(zps));
|
||||
__m256i vzps_high =
|
||||
_mm256_set1_epi64x(*reinterpret_cast<const int64_t*>(zps + 8));
|
||||
__m256i shuffle_mask =
|
||||
_mm256_set_epi8(7, 7, 7, 7, 6, 6, 6, 6, 5, 5, 5, 5, 4, 4, 4, 4, 3, 3, 3,
|
||||
3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0);
|
||||
vzps_low = _mm256_shuffle_epi8(vzps_low, shuffle_mask);
|
||||
vzps_high = _mm256_shuffle_epi8(vzps_high, shuffle_mask);
|
||||
m256i_wrapper vzps_low_wp, vzps_high_wp;
|
||||
vzps_low_wp.data = vzps_low;
|
||||
vzps_high_wp.data = vzps_high;
|
||||
return {vzps_low_wp, vzps_high_wp};
|
||||
}
|
||||
|
||||
inline std::array<m256i_wrapper, 2> load_uint4_as_int8(
|
||||
const uint8_t* __restrict__ qB) {
|
||||
__m256i packed = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(qB));
|
||||
const __m256i low_mask = _mm256_set1_epi8(0x0f);
|
||||
__m256i high = _mm256_srli_epi16(packed, 4);
|
||||
high = _mm256_and_si256(high, low_mask);
|
||||
__m256i low = _mm256_and_si256(packed, low_mask);
|
||||
m256i_wrapper low_wp, high_wp;
|
||||
low_wp.data = low;
|
||||
high_wp.data = high;
|
||||
return {low_wp, high_wp};
|
||||
}
|
||||
|
||||
template <int N, int ldb>
|
||||
void _dequant_weight_zp_only(const uint8_t* __restrict__ B, int8_t* dqB,
|
||||
const int8_t* __restrict__ qzeros, int64_t K) {
|
||||
#pragma GCC unroll 2
|
||||
for (int n = 0; n < N; n += 16) {
|
||||
auto [zps_low_wp, zps_high_wp] = load_zps_4vnni(&qzeros[n]);
|
||||
auto zps_low = zps_low_wp.data;
|
||||
auto zps_high = zps_high_wp.data;
|
||||
for (int k = 0; k < K; k += 4) {
|
||||
auto [vb_low_wp, vb_high_wp] =
|
||||
load_uint4_as_int8(B + ldb * k + n / 2 * 4);
|
||||
auto vb_low = vb_low_wp.data;
|
||||
auto vb_high = vb_high_wp.data;
|
||||
vb_high = _mm256_sub_epi8(vb_high, zps_high);
|
||||
vb_low = _mm256_sub_epi8(vb_low, zps_low);
|
||||
_mm256_storeu_si256(reinterpret_cast<__m256i_u*>(dqB + N * k + n * 4),
|
||||
vb_low);
|
||||
_mm256_storeu_si256(
|
||||
reinterpret_cast<__m256i_u*>(dqB + N * k + (n + 8) * 4), vb_high);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <bool sym_quant_act, int N, bool accum>
|
||||
void _dequant_and_store(float* __restrict__ output,
|
||||
const int32_t* __restrict__ input,
|
||||
const float* __restrict__ scale_a,
|
||||
const int32_t* __restrict__ zp_a,
|
||||
const float* __restrict__ scale_b,
|
||||
const int32_t* __restrict__ comp_b, int M, int ldi,
|
||||
int ldo, int ldsa = 1) {
|
||||
for (int m = 0; m < M; ++m) {
|
||||
float a_scale = *(scale_a + m * ldsa);
|
||||
__m512 va_scale = _mm512_set1_ps(a_scale);
|
||||
int32_t a_zp;
|
||||
__m512i va_zp;
|
||||
if constexpr (!sym_quant_act) {
|
||||
a_zp = *(zp_a + m * ldsa);
|
||||
va_zp = _mm512_set1_epi32(a_zp);
|
||||
}
|
||||
int n = 0;
|
||||
#pragma GCC unroll 2
|
||||
for (; n < N; n += 16) {
|
||||
__m512i vc = _mm512_loadu_si512(input + m * ldi + n);
|
||||
if constexpr (!sym_quant_act) {
|
||||
__m512i vb_comp = _mm512_loadu_si512(comp_b + n);
|
||||
vc = _mm512_sub_epi32(vc, _mm512_mullo_epi32(vb_comp, va_zp));
|
||||
}
|
||||
__m512 vc_f = _mm512_cvtepi32_ps(vc);
|
||||
__m512 vc_f_mul = _mm512_mul_ps(vc_f, va_scale);
|
||||
__m512 vb_s = _mm512_loadu_ps(scale_b + n);
|
||||
vc_f_mul = _mm512_mul_ps(vc_f_mul, vb_s);
|
||||
if constexpr (accum) {
|
||||
__m512 vo = _mm512_loadu_ps(output + m * ldo + n);
|
||||
_mm512_storeu_ps(output + m * ldo + n, _mm512_add_ps(vo, vc_f_mul));
|
||||
} else {
|
||||
_mm512_storeu_ps(output + m * ldo + n, vc_f_mul);
|
||||
}
|
||||
}
|
||||
for (; n < N; ++n) {
|
||||
float dq_val;
|
||||
if constexpr (sym_quant_act) {
|
||||
dq_val = (float)input[m * ldi + n] * a_scale * scale_b[n];
|
||||
} else {
|
||||
dq_val = (float)(input[m * ldi + n] - a_zp * comp_b[n]) * a_scale *
|
||||
scale_b[n];
|
||||
}
|
||||
if constexpr (accum) {
|
||||
output[m * ldo + n] += dq_val;
|
||||
} else {
|
||||
output[m * ldo + n] = dq_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#else
|
||||
template <int N, int ldb>
|
||||
void _dequant_weight_zp_only(const uint8_t* B, int8_t* dqB,
|
||||
const int8_t* qzeros, int64_t K) {
|
||||
for (int k = 0; k < K; ++k) {
|
||||
for (int n = 0; n < N / 2; ++n) {
|
||||
int32_t b = (int32_t)B[k * ldb + n];
|
||||
dqB[k * N + n * 2] = (b & 0xf) - qzeros[n];
|
||||
dqB[k * N + n * 2 + 1] = (b >> 4) - qzeros[n];
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
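The scalar fallback above spells out the unpacking rule: each byte holds two 4-bit weights (low nibble first) and both share the same zero point. An eager-PyTorch sketch of the same dequantization (mirroring the scalar path, not the AVX-512 VNNI-packed layout):

import torch

def dequant_uint4_reference(packed, qzeros):
    # packed: [K, N // 2] uint8, two 4-bit weights per byte
    # qzeros: [N // 2] zero points, one per packed byte
    low = (packed & 0x0F).to(torch.int8)
    high = (packed >> 4).to(torch.int8)
    dq = torch.stack((low, high), dim=-1).flatten(-2)  # [K, N], low nibble first
    zp = qzeros.to(torch.int8).repeat_interleave(2)
    return dq - zp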
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
inline __m512i combine_m256i(__m256i a, __m256i b) {
|
||||
__m512i c = _mm512_castsi256_si512(a);
|
||||
return _mm512_inserti64x4(c, b, 1);
|
||||
}
|
||||
|
||||
inline __m512i combine_m256i(std::array<m256i_wrapper, 2> two_256) {
|
||||
return combine_m256i(two_256[0].data, two_256[1].data);
|
||||
}
|
||||
|
||||
static inline __m512i _mm512_sign_epi8(__m512i a, __m512i b) {
|
||||
__m512i zero = _mm512_setzero_si512();
|
||||
__mmask64 blt0 = _mm512_movepi8_mask(b);
|
||||
return _mm512_mask_sub_epi8(a, blt0, zero, a);
|
||||
}

template <bool sym_quant_act, int M, int N, int ldb>
void _dequant_gemm_accum_small_M(float* __restrict__ C, const uint8_t* A,
                                 const float* scales_a, const int32_t* qzeros_a,
                                 const uint8_t* B, const float* scales_b,
                                 const int8_t* qzeros_b, int64_t K, int64_t lda,
                                 int64_t ldc) {
  constexpr int COLS = N / 16;
  __m512i ones = _mm512_set1_epi8(1);
  __m512i va;
  __m512i vb[COLS];
  __m512i vc[M * COLS];
  __m512 vscales[COLS];
  __m512i vzps[COLS];
  __m512i vcompensate[COLS];

  Unroll<COLS>{}([&](auto i) {
    vscales[i] = _mm512_loadu_ps(scales_b + i * 16);
    vzps[i] = combine_m256i(load_zps_4vnni(qzeros_b + i * 16));
    if constexpr (!sym_quant_act) {
      vcompensate[i] = _mm512_setzero_epi32();
    }
  });
  Unroll<M * COLS>{}([&](auto i) { vc[i] = _mm512_setzero_epi32(); });

  auto compute = [&](auto i, int k) {
    constexpr const int row = i / COLS;
    constexpr const int col = i % COLS;

    if constexpr (col == 0) {
      va = _mm512_set1_epi32(*(int32_t*)(A + row * lda + k));
    }

    if constexpr (row == 0) {
      int B_offset = k * ldb + col * 16 * 2;
      vb[col] = combine_m256i(load_uint4_as_int8(B + B_offset));
      vb[col] = _mm512_sub_epi8(vb[col], vzps[col]);
      if constexpr (!sym_quant_act) {
        vcompensate[col] = _mm512_dpbusd_epi32(vcompensate[col], ones, vb[col]);
      }
      _mm_prefetch(B + B_offset + 128 * ldb, _MM_HINT_T0);
    }
    if constexpr (sym_quant_act) {
      auto vsb = _mm512_sign_epi8(vb[col], va);
      auto vabsa = _mm512_sign_epi8(va, va);
      vc[i] = _mm512_dpbusds_epi32(vc[i], vabsa, vsb);
    } else {
      vc[i] = _mm512_dpbusd_epi32(vc[i], va, vb[col]);
    }
  };

  constexpr const int unroll = 4;
  int k = 0;
  for (; k < K / 4 / unroll; k++) {
    Unroll<unroll>{}(
        [&](auto i) { Unroll<M * COLS>{}(compute, 4 * (k * unroll + i)); });
  }
  k *= 4 * unroll;
  for (; k < K; k += 4) {
    Unroll<M * COLS>{}(compute, k);
  }

  auto store = [&](auto i) {
    constexpr const int row = i / COLS;
    constexpr const int col = i % COLS;
    __m512 vc_float;
    if constexpr (!sym_quant_act) {
      vc[i] = _mm512_sub_epi32(
          vc[i], _mm512_mullo_epi32(vcompensate[col],
                                    _mm512_set1_epi32(*(qzeros_a + row))));
    }
    vc_float = _mm512_cvtepi32_ps(vc[i]);
    vc_float = _mm512_mul_ps(vc_float, _mm512_set1_ps(*(scales_a + row)));

    vc_float = _mm512_mul_ps(vc_float, vscales[col]);
    auto vc_old = _mm512_loadu_ps(C + row * ldc + col * 16);
    vc_float = _mm512_add_ps(vc_float, vc_old);
    _mm512_storeu_ps(C + row * ldc + col * 16, vc_float);
  };
  Unroll<M * COLS>{}(store);
}
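
// The inner product above relies on AVX512-VNNI: each _mm512_dpbusd_epi32 lane
// accumulates a dot product of 4 unsigned bytes with 4 signed bytes. For
// sym_quant_act (signed int8 activations) the sign trick rewrites a * b as
// |a| * sign(a, b) so that the first operand stays unsigned, as dpbusd
// requires.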

#define CALL_DEQUANT_GEMM_ACCUM_SMALL_M(M)               \
  _dequant_gemm_accum_small_M<sym_quant_act, M, N, ldb>( \
      C, A, scales_a, qzeros_a, B, scales_b, qzeros_b, K, lda, ldc);
#endif

template <bool sym_quant_act, int N, int ldb>
void _dequant_gemm_accum(float* C, const uint8_t* A, const float* scales_a,
                         const int32_t* qzeros_a, const uint8_t* B,
                         const float* scales_b, const int8_t* qzeros_b,
                         const int32_t* compensation, int8_t* dqB, int64_t M,
                         int64_t K, int64_t lda, int64_t ldc, bool use_brgemm) {
#if defined(CPU_CAPABILITY_AVX512)
  if (!use_brgemm) {
    switch (M) {
      case 1:
        CALL_DEQUANT_GEMM_ACCUM_SMALL_M(1);
        break;
      case 2:
        CALL_DEQUANT_GEMM_ACCUM_SMALL_M(2);
        break;
      case 3:
        CALL_DEQUANT_GEMM_ACCUM_SMALL_M(3);
        break;
      case 4:
        CALL_DEQUANT_GEMM_ACCUM_SMALL_M(4);
        break;
      default:
        TORCH_CHECK(false, "tinygemm_kernel: unexpected M for AVX path!");
    }
    return;
  }

  _dequant_weight_zp_only<N, ldb>(B, dqB, qzeros_b, K);
  using Tin = typename ActDtype<sym_quant_act>::type;
  Tin* A_ptr = (Tin*)A;
  if (use_brgemm) {
    int32_t C_i32[M * N];
    at::native::cpublas::brgemm(M, N, K, lda, N /*ldb*/, N /*ldc*/,
                                false /* add_C */, A_ptr, dqB, C_i32,
                                true /* is_vnni */);
    _mm_prefetch(B + N * K / 2, _MM_HINT_T0);
    _mm_prefetch(A + K, _MM_HINT_T0);
    _dequant_and_store<sym_quant_act, N, true>(C, C_i32, scales_a, qzeros_a,
                                               scales_b, compensation, M,
                                               N /*ldi*/, ldc, 1 /*ldsa*/);
  } else
#endif
  {
    TORCH_CHECK(false, "tinygemm_kernel: scalar path not implemented!");
  }
}

template <int N>
inline void copy_bias(const float* bias_ptr, float* y_buf, int64_t m) {
  if (bias_ptr) {
    for (int i = 0; i < m; ++i) {
      int j = 0;
#if defined(CPU_CAPABILITY_AVX512)
#pragma GCC unroll 2
      for (; j < N; j += 16) {
        __m512 bias_vec = _mm512_loadu_ps(bias_ptr + j);
        _mm512_storeu_ps(y_buf + i * N + j, bias_vec);
      }
#endif
      for (; j < N; ++j) {
        y_buf[i * N + j] = bias_ptr[j];
      }
    }
  } else {
    for (int i = 0; i < m; ++i) {
      int j = 0;
#if defined(CPU_CAPABILITY_AVX512)
#pragma GCC unroll 2
      for (; j < N; j += 16) {
        __m512 zero_vec = _mm512_setzero_ps();
        _mm512_storeu_ps(y_buf + i * N + j, zero_vec);
      }
#endif
      for (; j < N; ++j) {
        y_buf[i * N + j] = 0;
      }
    }
  }
}

template <int N, typename out_dtype>
inline void store_out(const float* y_buf, out_dtype* c_ptr, int64_t m,
                      int64_t lda) {
  for (int i = 0; i < m; ++i) {
    int j = 0;
    if constexpr (std::is_same<out_dtype, float>::value) {
#if defined(CPU_CAPABILITY_AVX512)
#pragma GCC unroll 2
      for (; j < N; j += 16) {
        __m512 y_vec = _mm512_loadu_ps(y_buf + i * N + j);
        _mm512_storeu_ps(c_ptr + i * lda + j, y_vec);
      }
#endif
      for (; j < N; ++j) {
        c_ptr[i * lda + j] = y_buf[i * N + j];
      }
    } else if constexpr (std::is_same<out_dtype, at::BFloat16>::value) {
#if defined(CPU_CAPABILITY_AVX512)
#pragma GCC unroll 2
      for (; j < N; j += 16) {
        __m512 y_vec = _mm512_loadu_ps(y_buf + i * N + j);
        __m256i y_bf16_vec = at::vec::cvtfp32_bf16(y_vec);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(c_ptr + i * lda + j),
                            y_bf16_vec);
      }
#endif
      for (; j < N; ++j) {
        c_ptr[i * lda + j] = at::BFloat16(y_buf[i * N + j]);
      }
    } else if constexpr (std::is_same<out_dtype, at::Half>::value) {
#if defined(CPU_CAPABILITY_AVX512)
#pragma GCC unroll 2
      for (; j < N; j += 16) {
        __m512 y_vec = _mm512_loadu_ps(y_buf + i * N + j);
        __m256i y_fp16_vec = at::vec::cvtfp32_fp16(y_vec);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(c_ptr + i * lda + j),
                            y_fp16_vec);
      }
#endif
      for (; j < N; ++j) {
        c_ptr[i * lda + j] = at::Half(y_buf[i * N + j]);
      }
    } else {
      TORCH_CHECK(false, "Unsupported output dtype");
    }
  }
}

void fill_val_stub(int32_t* __restrict__ output, int32_t value, int64_t size) {
  using iVec = at::vec::Vectorized<int32_t>;
  constexpr int VecSize = iVec::size();
  const iVec fill_val_vec = iVec(value);
  int64_t d;
#pragma GCC unroll 4
  for (d = 0; d <= size - VecSize; d += VecSize) {
    fill_val_vec.store(output + d);
  }
  for (; d < size; ++d) {
    output[d] = value;
  }
}

template <bool sym_quant_act, typename act_dtype, typename out_dtype>
void _da8w4_linear_impl(
    act_dtype* __restrict__ input, const float* __restrict__ input_scales,
    const int32_t* __restrict__ input_qzeros,
    const uint8_t* __restrict__ weight, const float* __restrict__ weight_scales,
    const int8_t* __restrict__ weight_qzeros, const float* __restrict__ bias,
    out_dtype* __restrict__ output, float* __restrict__ output_temp,
    int8_t* __restrict__ dequant_weight_temp, int64_t M, int64_t N, int64_t K,
    int64_t num_groups) {
  const bool use_brgemm = can_use_brgemm<act_dtype>(M);
  int64_t block_m = [&]() -> long {
    if (M <= 48) {
      return M;
    } else if (M < 64) {
      return 32;
    } else if (M < 96) {
      return 64;
    } else {
      return 128;
    }
  }();
  int64_t Mc = div_up(M, block_m);
  bool parallel_on_M = M > 128;
  int64_t Nc = N / BLOCK_N;
  int64_t num_blocks = parallel_on_M ? Mc * Nc : Nc;
  int64_t group_size = div_up(K, num_groups);
  int64_t _block_k = get_4bit_block_k_size(group_size);
  int64_t Kc = K / _block_k;
  int64_t block_per_group = group_size / _block_k;

  at::parallel_for(0, num_blocks, 1, [&](int64_t begin, int64_t end) {
    int tid = get_thread_num();
    float* C_tmp = output_temp + tid * block_m * BLOCK_N;
    int8_t* dqB_tmp = dequant_weight_temp + tid * _block_k * BLOCK_N;
    for (const auto i : c10::irange(begin, end)) {
      int64_t mc = parallel_on_M ? i / Nc : 0;
      int64_t nc = parallel_on_M ? i % Nc : i;
      int64_t mc_end = parallel_on_M ? mc + 1 : Mc;

      for (int mci = mc; mci < mc_end; ++mci) {
        int64_t m_size =
            mci * block_m + block_m > M ? M - mci * block_m : block_m;
        auto bias_data = bias ? bias + nc * BLOCK_N : nullptr;
        copy_bias<BLOCK_N>(bias_data, C_tmp, m_size);
        for (int kci = 0; kci < Kc; ++kci) {
          int32_t* compensation_ptr =
              sym_quant_act
                  ? nullptr
                  : (int32_t*)(void*)(weight +
                                      (nc * Kc + kci) *
                                          (BLOCK_N *
                                           (_block_k / 2 + sizeof(int32_t))) +
                                      _block_k * BLOCK_N / 2);
          _dequant_gemm_accum<sym_quant_act, BLOCK_N, BLOCK_N / 2>(
              /*C*/ C_tmp,
              /*A*/ (uint8_t*)input + mci * block_m * K + kci * _block_k,
              /*scales_a*/ input_scales + mci * block_m,
              /*qzeros_a*/ input_qzeros + mci * block_m,
              /*B*/ weight + (nc * Kc + kci) *
                                 (BLOCK_N * (_block_k / 2 + sizeof(int32_t))),
              /*scales_b*/ weight_scales + nc * BLOCK_N * num_groups +
                  kci / block_per_group * BLOCK_N,
              /*qzeros_b*/ weight_qzeros + nc * BLOCK_N * num_groups +
                  kci / block_per_group * BLOCK_N,
              /*Bcomp*/ compensation_ptr,
              /*dqB_tmp*/ dqB_tmp,
              /*M*/ m_size,
              /*K*/ _block_k,
              /*lda*/ K,
              /*ldc*/ BLOCK_N,
              /*use_brgemm*/ use_brgemm);
        }
        store_out<BLOCK_N>(C_tmp, output + mci * block_m * N + nc * BLOCK_N,
                           m_size, N /*lda*/);
      }
    }
    if (use_brgemm) {
      at::native::cpublas::brgemm_release();
    }
  });
}
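
// Parallelization note for the implementation above: work is split over N
// blocks of width BLOCK_N, and additionally over M blocks of height block_m
// once M > 128. Each thread writes only its own C_tmp / dqB_tmp scratch slice,
// so no synchronization is needed inside the parallel_for body.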

}  // anonymous namespace

std::tuple<at::Tensor, at::Tensor, at::Tensor>
convert_int4_weight_packed_with_compensation(const at::Tensor& weight,
                                             const at::Tensor& scales,
                                             const at::Tensor& qzeros) {
  TORCH_CHECK(weight.dim() == 2,
              "DA8W4 CPU: Weight should be a 2D tensor for packing");
  TORCH_CHECK(
      weight.size(1) % 2 == 0,
      "DA8W4 CPU: Weight should have even number of columns for packing");

  auto new_scales = scales;
  auto new_qzeros = qzeros;
  if (new_scales.dim() == 1) {
    new_scales.unsqueeze_(1);
  }
  new_scales = new_scales.to(at::kFloat);
  if (new_qzeros.dim() == 1) {
    new_qzeros.unsqueeze_(1);
  }
  new_qzeros = new_qzeros.to(at::kChar);
  int64_t N = weight.size(0);
  int64_t K = weight.size(1);
  int64_t G = scales.size(1);
  int64_t group_size = K / G;
  int64_t _block_k = get_4bit_block_k_size(group_size);
  constexpr int block_n = block_size_n();
  int64_t Nc = N / block_n;
  int64_t Kc = K / _block_k;

  auto weight_view = weight.view({Nc, block_n, Kc, _block_k});
  at::Tensor weight_reordered = weight_view.permute({0, 2, 3, 1}).contiguous();
  at::Tensor blocked_weight;
  at::Tensor blocked_scales =
      new_scales.view({Nc, block_n, G}).permute({0, 2, 1}).contiguous();
  at::Tensor blocked_qzeros =
      new_qzeros.view({Nc, block_n, G}).permute({0, 2, 1}).contiguous();
  auto weight_sub_qzero = weight.view({Nc, block_n, G, -1}).to(at::kInt) -
                          new_qzeros.view({Nc, block_n, G, -1});
  weight_sub_qzero = weight_sub_qzero.view({Nc, block_n, Kc, _block_k});
  at::Tensor compensation = weight_sub_qzero.sum(-1);
  compensation = compensation.permute({0, 2, 1}).contiguous().to(at::kInt);
  int64_t buffer_size_nbytes =
      _block_k * block_n / 2 + block_n * sizeof(int32_t);
  blocked_weight = at::empty({Nc, Kc, buffer_size_nbytes}, weight.options());

  auto weight_ptr = weight_reordered.data_ptr<uint8_t>();
  auto compensation_ptr = compensation.data_ptr<int32_t>();
  auto blocked_weight_ptr = blocked_weight.data_ptr<uint8_t>();
  int64_t num_blocks = Nc * Kc;
  at::parallel_for(0, num_blocks, 1, [&](int64_t begin, int64_t end) {
    for (const auto i : c10::irange(begin, end)) {
      auto in_ptr = weight_ptr + i * _block_k * block_n;
      auto out_ptr =
          blocked_weight_ptr + i * block_n * (_block_k / 2 + sizeof(int32_t));
      int32_t* comp_in_prt = compensation_ptr + i * block_n;
      int32_t* comp_out_prt =
          (int32_t*)(void*)(blocked_weight_ptr +
                            i * block_n * (_block_k / 2 + sizeof(int32_t)) +
                            _block_k * block_n / 2);
      constexpr int n_group_size = 8;
      constexpr int vnni_size = 4;
      constexpr int n_group = block_n / n_group_size;
      for (int nb = 0; nb < n_group; nb += 2) {
        for (int k = 0; k < _block_k; k += vnni_size) {
          for (int ni = 0; ni < n_group_size; ++ni) {
            for (int ki = 0; ki < vnni_size; ++ki) {
              int src_idx_1 = nb * n_group_size + ni + (k + ki) * block_n;
              int src_idx_2 = (nb + 1) * n_group_size + ni + (k + ki) * block_n;
              int dst_idx = (nb / 2 * n_group_size + ni) * vnni_size +
                            k * block_n / 2 + ki;
              uint8_t src_1 = *(in_ptr + src_idx_1);
              uint8_t src_2 = *(in_ptr + src_idx_2);
              uint8_t dst = (src_1 & 0x0f) | ((src_2 & 0x0f) << 4);
              *(out_ptr + dst_idx) = dst;
            }
          }
        }
      }
      for (int nb = 0; nb < block_n; nb++) {
        *(comp_out_prt + nb) = *(comp_in_prt + nb);
      }
    }
  });

  return std::make_tuple(std::move(blocked_weight), std::move(blocked_scales),
                         std::move(blocked_qzeros));
}
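
// Resulting per-(Nc, Kc) block layout: _block_k * block_n / 2 bytes of
// VNNI-interleaved packed int4 weights, followed by block_n int32 compensation
// values (the sum of (w - qzero) over the K block), which the
// asymmetric-activation path uses to correct the unsigned dot products.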

std::tuple<at::Tensor, at::Tensor> autoawq_to_int4pack(at::Tensor qweight,
                                                       at::Tensor qzeros) {
  auto bitshifts = at::tensor({0, 4, 1, 5, 2, 6, 3, 7}, at::kInt) * 4;
  auto qweight_unsq = qweight.unsqueeze(-1);
  auto unpacked = at::bitwise_right_shift(qweight_unsq, bitshifts) & 0xF;
  auto qweight_final = unpacked.flatten(-2).transpose(-1, -2).to(at::kByte);

  auto qzeros_unsq = qzeros.unsqueeze(-1);
  auto qzeros_unpacked = at::bitwise_right_shift(qzeros_unsq, bitshifts) & 0xF;
  auto qzeros_final = qzeros_unpacked.flatten(-2).to(at::kByte);

  return std::make_tuple(qweight_final, qzeros_final);
}
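
// The {0, 4, 1, 5, 2, 6, 3, 7} shift order above undoes AutoAWQ's interleaved
// int4 packing: a packed int32 holding logical values v0..v7 stores them as
// nibbles in the order v0, v2, v4, v6, v1, v3, v5, v7, so shifting by
// 4 * {0, 4, 1, 5, 2, 6, 3, 7} bits recovers v0..v7 in order.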

std::tuple<at::Tensor, at::Tensor, at::Tensor> convert_weight_packed_scale_zp(
    at::Tensor qweight, at::Tensor qzeros, at::Tensor scales) {
  auto res = autoawq_to_int4pack(qweight, qzeros);
  auto _qweight = std::get<0>(res);
  auto _qzeros = std::get<1>(res);
  auto _scales = scales;
  _qzeros = _qzeros.transpose(-2, -1).contiguous();
  _scales = _scales.transpose(-2, -1).contiguous();
  if (_qweight.dim() == 3) {
    int64_t E = _qweight.size(0);
    int64_t K = _qweight.size(2);
    int64_t G = _scales.size(2);
    int64_t group_size = K / G;
    int64_t _block_k = get_4bit_block_k_size(group_size);
    int64_t block_n = block_size_n();
    int64_t Nc = _qweight.size(1) / block_n;
    int64_t Kc = K / _block_k;
    int64_t buffer_size_nbytes =
        _block_k * block_n / 2 + block_n * sizeof(int32_t);
    auto blocked_weight =
        at::empty({E, Nc, Kc, buffer_size_nbytes}, _qweight.options());
    auto blocked_scales =
        at::empty({E, Nc, G, block_n}, _scales.options()).to(at::kFloat);
    auto blocked_qzeros =
        at::empty({E, Nc, G, block_n}, _qzeros.options()).to(at::kChar);
    for (int i = 0; i < _qweight.size(0); i++) {
      auto res_ = convert_int4_weight_packed_with_compensation(
          _qweight[i], _scales[i], _qzeros[i]);
      blocked_weight[i] = std::get<0>(res_);
      blocked_scales[i] = std::get<1>(res_);
      blocked_qzeros[i] = std::get<2>(res_);
    }
    _qweight = blocked_weight;
    _scales = blocked_scales;
    _qzeros = blocked_qzeros;
  } else {
    auto res_ = convert_int4_weight_packed_with_compensation(_qweight, _scales,
                                                             _qzeros);
    _qweight = std::get<0>(res_);
    _scales = std::get<1>(res_);
    _qzeros = std::get<2>(res_);
  }

  return std::make_tuple(_qweight, _qzeros, _scales);
}
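
// Handles both 2-D weights and 3-D per-expert weight stacks (e.g. MoE layers),
// packing each expert independently. Note the return order is
// (packed weight, qzeros, scales).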

at::Tensor int4_scaled_mm_cpu_with_quant(const at::Tensor& input,
                                         const at::Tensor& weight,
                                         const at::Tensor& weight_scales,
                                         const at::Tensor& weight_qzeros,
                                         const std::optional<at::Tensor>& bias,
                                         at::ScalarType output_dtype) {
  RECORD_FUNCTION("vllm::int4_scaled_mm_cpu_with_quant",
                  std::vector<c10::IValue>({input, weight}));

  int64_t M_a = input.size(0);
  int64_t K_a = input.size(1);
  int64_t lda = input.stride(0);

  const auto st = input.scalar_type();
  TORCH_CHECK(
      st == at::kBFloat16 || st == at::kHalf,
      "int4_scaled_mm_cpu_with_quant: expect A to be bfloat16 or half.");

  constexpr bool sym_quant_act = false;
  using Tin = typename ActDtype<sym_quant_act>::type;
  int64_t act_buffer_size =
      M_a * K_a + M_a * sizeof(float) + M_a * sizeof(int32_t);
  auto act_buffer =
      at::empty({act_buffer_size}, input.options().dtype(at::kByte));
  auto Aq_data = act_buffer.data_ptr<uint8_t>();
  auto As_data = reinterpret_cast<float*>(Aq_data + M_a * K_a);
  auto Azp_data = reinterpret_cast<int32_t*>(As_data + M_a);
  fill_val_stub(Azp_data, 128, M_a);

  auto out_sizes = input.sizes().vec();
  int64_t N = weight_scales.size(0) * weight_scales.size(-1);
  out_sizes.back() = N;
  auto output = at::empty(out_sizes, input.options());
  int64_t Nc = weight.size(0);
  int64_t Kc = weight.size(1);
  int64_t _block_k = K_a / Kc;
  TORCH_CHECK(N == Nc * BLOCK_N, "DA8W4: weight and input shapes mismatch");
  int64_t num_groups = weight_scales.size(1);

  const uint8_t* b_ptr = weight.data_ptr<uint8_t>();
  const float* b_scales_ptr = weight_scales.data_ptr<float>();
  const int8_t* b_qzeros_ptr = weight_qzeros.data_ptr<int8_t>();
  const float* bias_ptr =
      bias.has_value() ? bias.value().data_ptr<float>() : nullptr;
  int num_threads = at::get_num_threads();
  int64_t temp_buffer_size = num_threads * BLOCK_M * BLOCK_N * sizeof(float) +
                             num_threads * _block_k * BLOCK_N;
  auto c_temp_buffer =
      at::empty({temp_buffer_size}, input.options().dtype(at::kChar));
  float* c_temp_ptr = (float*)((void*)(c_temp_buffer.data_ptr<int8_t>()));
  int8_t* dqB_temp_ptr =
      (int8_t*)((void*)(c_temp_ptr + num_threads * BLOCK_M * BLOCK_N));

#define LAUNCH_DA8W4_LINEAR_WITH_QUANT_IMPL(sym_quant_act)                 \
  AT_DISPATCH_FLOATING_TYPES_AND2(                                         \
      at::ScalarType::BFloat16, at::ScalarType::Half, output_dtype,        \
      "int4_scaled_mm_cpu", [&] {                                          \
        const scalar_t* __restrict__ A_data = input.data_ptr<scalar_t>();  \
        scalar_t* __restrict__ c_ptr = output.data_ptr<scalar_t>();        \
        at::parallel_for(0, M_a, 0, [&](int64_t begin, int64_t end) {      \
          for (int64_t m = begin; m < end; ++m) {                          \
            quantize_row_int8<scalar_t>(Aq_data + m * K_a, As_data[m],     \
                                        A_data + m * lda, K_a);            \
          }                                                                \
        });                                                                \
        _da8w4_linear_impl<sym_quant_act, Tin, scalar_t>(                  \
            Aq_data, As_data, Azp_data, b_ptr, b_scales_ptr, b_qzeros_ptr, \
            bias_ptr, c_ptr, c_temp_ptr, dqB_temp_ptr, M_a, N, K_a,        \
            num_groups);                                                   \
      });

  LAUNCH_DA8W4_LINEAR_WITH_QUANT_IMPL(sym_quant_act);

  return output;
}
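
// This entry point always quantizes activations asymmetrically to uint8
// (sym_quant_act = false) with a fixed per-row zero point of 128; the
// compensation terms packed alongside each weight block absorb that zero point
// during accumulation.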

namespace {

template <typename scalar_t>
inline void copy_stub(scalar_t* __restrict__ out,
                      const float* __restrict__ input, int64_t size) {
  using Vec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;
#pragma GCC unroll 4
  for (int64_t d = 0; d < size; d += Vec::size()) {
    fVec x0 = fVec::loadu(input + d);
    fVec x1 = fVec::loadu(input + d + fVec::size());
    Vec res = convert_from_float_ext<scalar_t>(x0, x1);
    res.store(out + d);
  }
}

}  // anonymous namespace

template <typename scalar_t>
void tinygemm_kernel(scalar_t* C, float* C_temp, const uint8_t* A,
                     const float* scales_a, const int32_t* qzeros_a,
                     const uint8_t* B, const float* scales_b,
                     const int8_t* qzeros_b, const int32_t* compensation,
                     int8_t* dqB_tmp, int64_t M, int64_t K, int64_t lda,
                     int64_t ldc_f, int64_t ldc_s, bool store_out,
                     bool use_brgemm) {
  _dequant_gemm_accum<false, BLOCK_N, BLOCK_N / 2>(
      C_temp, A, scales_a, qzeros_a, B, scales_b, qzeros_b, compensation,
      dqB_tmp, M, K, lda, ldc_f, use_brgemm);
  if (store_out) {
    for (int64_t m = 0; m < M; ++m) {
      copy_stub<scalar_t>(C + m * ldc_s, C_temp + m * ldc_f, BLOCK_N);
    }
  }
}

#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE)                                 \
  template void tinygemm_kernel<TYPE>(                                      \
      TYPE * C, float* C_temp, const uint8_t* A, const float* scales_a,     \
      const int32_t* qzeros_a, const uint8_t* B, const float* scales_b,     \
      const int8_t* qzeros_b, const int32_t* compensation, int8_t* dqB_tmp, \
      int64_t M, int64_t K, int64_t lda, int64_t ldc_f, int64_t ldc_s,      \
      bool store_out, bool use_brgemm)

INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16);
INSTANTIATE_TINYGEMM_TEMPLATE(at::Half);

at::Tensor int4_scaled_mm_cpu(at::Tensor& x, at::Tensor& w, at::Tensor& w_zeros,
                              at::Tensor& w_scales,
                              std::optional<at::Tensor> bias) {
  return int4_scaled_mm_cpu_with_quant(x, w, w_scales, w_zeros, bias,
                                       x.scalar_type());
}

@@ -79,6 +79,14 @@ at::Tensor int8_scaled_mm_with_quant(at::Tensor& mat1, at::Tensor& mat2,
                                     const std::optional<at::Tensor>& bias,
                                     at::ScalarType out_dtype, bool is_vnni);

// Adapted from sglang: INT4 W4A8 kernels
std::tuple<at::Tensor, at::Tensor, at::Tensor> convert_weight_packed_scale_zp(
    at::Tensor qweight, at::Tensor qzeros, at::Tensor scales);

at::Tensor int4_scaled_mm_cpu(at::Tensor& x, at::Tensor& w, at::Tensor& w_zeros,
                              at::Tensor& w_scales,
                              std::optional<at::Tensor> bias);

torch::Tensor get_scheduler_metadata(
    const int64_t num_req, const int64_t num_heads_q,
    const int64_t num_heads_kv, const int64_t head_dim,

@@ -285,6 +293,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
      "Tensor? bias, ScalarType out_dtype, bool is_vnni) -> Tensor");
  ops.impl("int8_scaled_mm_with_quant", torch::kCPU,
           &int8_scaled_mm_with_quant);

  // Adapted from sglang: INT4 W4A8 kernels
  ops.def(
      "convert_weight_packed_scale_zp(Tensor qweight, Tensor qzeros, "
      "Tensor scales) -> (Tensor, Tensor, Tensor)");
  ops.impl("convert_weight_packed_scale_zp", torch::kCPU,
           &convert_weight_packed_scale_zp);

  ops.def(
      "int4_scaled_mm_cpu(Tensor(a0!) x, Tensor(a1!) w, Tensor(a2!) w_zeros, "
      "Tensor(a3!) w_scales, Tensor? bias) -> Tensor");
  ops.impl("int4_scaled_mm_cpu", torch::kCPU, &int4_scaled_mm_cpu);
#endif

  // CPU attention kernels

@@ -3,8 +3,8 @@

#pragma once

#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
#include <torch/headeronly/util/BFloat16.h>
#include <torch/headeronly/util/Half.h>
#include <cassert>

#ifdef USE_ROCM

@@ -6,14 +6,16 @@
#include <cstdio>
#include <cstdlib>

#include <torch/headeronly/util/shim_utils.h>

/**
 * Helper function for checking CUTLASS errors
 */
#define CUTLASS_CHECK(status)                       \
  {                                                 \
    cutlass::Status error = status;                 \
    TORCH_CHECK(error == cutlass::Status::kSuccess, \
                cutlassGetStatusString(error));     \
#define CUTLASS_CHECK(status)                           \
  {                                                     \
    cutlass::Status error = status;                     \
    STD_TORCH_CHECK(error == cutlass::Status::kSuccess, \
                    cutlassGetStatusString(error));     \
  }

inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {

@@ -1,7 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
#include <torch/all.h>
|
||||
namespace cute {
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
|
||||
@@ -189,9 +189,9 @@ struct Sm90RowOrScalarBroadcastArray {
|
||||
}
|
||||
|
||||
auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
|
||||
Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
|
||||
Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
|
||||
Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
|
||||
cute::Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
|
||||
cute::Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
|
||||
cute::Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
|
||||
|
||||
for (int i = 0; i < size(tGS_gRow_flt); ++i) {
|
||||
if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
|
||||
@@ -211,8 +211,8 @@ struct Sm90RowOrScalarBroadcastArray {
|
||||
begin_loop(int epi_m, int epi_n) {
|
||||
if (epi_m == 0) { // Assumes M-major subtile loop
|
||||
if (!params.row_broadcast) return; // Do not issue LDS when row is scalar
|
||||
Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
|
||||
Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
|
||||
cute::Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
|
||||
cute::Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
|
||||
copy(tSR_sRow_flt, tSR_rRow_flt);
|
||||
}
|
||||
}
|
||||
@@ -241,9 +241,9 @@ struct Sm90RowOrScalarBroadcastArray {
|
||||
auto [m, n, k, l] = args.tile_coord_mnkl;
|
||||
using ThreadCount = decltype(size(args.tiled_copy));
|
||||
|
||||
Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row_array[l]), make_shape(M,N,1), params.dRow);
|
||||
Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
|
||||
Tensor sRow = make_tensor(make_smem_ptr(smem),
|
||||
cute::Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row_array[l]), make_shape(M,N,1), params.dRow);
|
||||
cute::Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
|
||||
cute::Tensor sRow = make_tensor(make_smem_ptr(smem),
|
||||
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N)
|
||||
//// G2S: Gmem to Smem
|
||||
auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
|
||||
@@ -251,16 +251,16 @@ struct Sm90RowOrScalarBroadcastArray {
|
||||
Stride<_0, _1>>{},
|
||||
Layout<_1>{});
|
||||
auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
|
||||
Tensor tGS_gRow = thr_g2s.partition_S(gRow);
|
||||
Tensor tGS_sRow = thr_g2s.partition_D(sRow);
|
||||
cute::Tensor tGS_gRow = thr_g2s.partition_S(gRow);
|
||||
cute::Tensor tGS_sRow = thr_g2s.partition_D(sRow);
|
||||
|
||||
//// G2S: Coord
|
||||
auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
|
||||
Tensor tGS_cRow = thr_g2s.partition_S(cRow);
|
||||
cute::Tensor tGS_cRow = thr_g2s.partition_S(cRow);
|
||||
|
||||
//// S2R: Smem to Reg
|
||||
Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||
Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
|
||||
cute::Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||
cute::Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
|
||||
|
||||
return ConsumerStoreCallbacks<decltype(tGS_gRow), decltype(tGS_sRow), decltype(tGS_cRow), decltype(tiled_g2s), decltype(tSR_sRow), decltype(tSR_rRow), decltype(args.tCcD), decltype(args.residue_cD), ThreadCount>(
|
||||
tGS_gRow,
|
||||
@@ -389,7 +389,7 @@ struct Sm90ColOrScalarBroadcastArray {
|
||||
|
||||
CUTLASS_DEVICE void
|
||||
begin() {
|
||||
Tensor pred = make_tensor<bool>(shape(tCgCol));
|
||||
cute::Tensor pred = make_tensor<bool>(shape(tCgCol));
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(pred); ++i) {
|
||||
pred(i) = get<0>(tCcCol(i)) < m;
|
||||
@@ -409,7 +409,7 @@ struct Sm90ColOrScalarBroadcastArray {
|
||||
CUTLASS_DEVICE Array<Element, FragmentSize>
|
||||
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
|
||||
Array<Element, FragmentSize> frg_col;
|
||||
Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
|
||||
cute::Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < FragmentSize; ++i) {
|
||||
@@ -431,16 +431,16 @@ struct Sm90ColOrScalarBroadcastArray {
|
||||
auto [M, N, K, L] = args.problem_shape_mnkl;
|
||||
auto [m, n, k, l] = args.tile_coord_mnkl;
|
||||
|
||||
Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col_array[l]), make_shape(M,N,1), params.dCol);
|
||||
Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
cute::Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col_array[l]), make_shape(M,N,1), params.dCol);
|
||||
cute::Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||
Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
cute::Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
|
||||
// Generate an identity tensor matching the shape of the global tensor and
|
||||
// partition the same way, this will be used to generate the predicate
|
||||
// tensor for loading
|
||||
Tensor cCol = make_identity_tensor(mCol.shape());
|
||||
Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
cute::Tensor cCol = make_identity_tensor(mCol.shape());
|
||||
cute::Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||
|
||||
return ConsumerStoreCallbacks(
|
||||
|
||||
@@ -186,9 +186,9 @@ struct Sm90RowOrScalarBroadcast {
|
||||
}
|
||||
|
||||
auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
|
||||
Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
|
||||
Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
|
||||
Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
|
||||
cute::Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
|
||||
cute::Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
|
||||
cute::Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
|
||||
|
||||
for (int i = 0; i < size(tGS_gRow_flt); ++i) {
|
||||
if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
|
||||
@@ -208,8 +208,8 @@ struct Sm90RowOrScalarBroadcast {
|
||||
begin_loop(int epi_m, int epi_n) {
|
||||
if (epi_m == 0) { // Assumes M-major subtile loop
|
||||
if (!params.row_broadcast) return; // Do not issue LDS when row is scalar
|
||||
Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
|
||||
Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
|
||||
cute::Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
|
||||
cute::Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
|
||||
copy(tSR_sRow_flt, tSR_rRow_flt);
|
||||
}
|
||||
}
|
||||
@@ -238,9 +238,9 @@ struct Sm90RowOrScalarBroadcast {
|
||||
auto [m, n, k, l] = args.tile_coord_mnkl;
|
||||
using ThreadCount = decltype(size(args.tiled_copy));
|
||||
|
||||
Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
|
||||
Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
|
||||
Tensor sRow = make_tensor(make_smem_ptr(smem),
|
||||
cute::Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
|
||||
cute::Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
|
||||
cute::Tensor sRow = make_tensor(make_smem_ptr(smem),
|
||||
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N)
|
||||
//// G2S: Gmem to Smem
|
||||
auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
|
||||
@@ -248,16 +248,16 @@ struct Sm90RowOrScalarBroadcast {
|
||||
Stride<_0, _1>>{},
|
||||
Layout<_1>{});
|
||||
auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
|
||||
Tensor tGS_gRow = thr_g2s.partition_S(gRow);
|
||||
Tensor tGS_sRow = thr_g2s.partition_D(sRow);
|
||||
cute::Tensor tGS_gRow = thr_g2s.partition_S(gRow);
|
||||
cute::Tensor tGS_sRow = thr_g2s.partition_D(sRow);
|
||||
|
||||
//// G2S: Coord
|
||||
auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
|
||||
Tensor tGS_cRow = thr_g2s.partition_S(cRow);
|
||||
cute::Tensor tGS_cRow = thr_g2s.partition_S(cRow);
|
||||
|
||||
//// S2R: Smem to Reg
|
||||
Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||
Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
|
||||
cute::Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||
cute::Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
|
||||
|
||||
return ConsumerStoreCallbacks<decltype(tGS_gRow), decltype(tGS_sRow), decltype(tGS_cRow), decltype(tiled_g2s), decltype(tSR_sRow), decltype(tSR_rRow), decltype(args.tCcD), decltype(args.residue_cD), ThreadCount>(
|
||||
tGS_gRow,
|
||||
@@ -382,7 +382,7 @@ struct Sm90ColOrScalarBroadcast {
|
||||
|
||||
CUTLASS_DEVICE void
|
||||
begin() {
|
||||
Tensor pred = make_tensor<bool>(shape(tCgCol));
|
||||
cute::Tensor pred = make_tensor<bool>(shape(tCgCol));
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(pred); ++i) {
|
||||
pred(i) = get<0>(tCcCol(i)) < m;
|
||||
@@ -402,7 +402,7 @@ struct Sm90ColOrScalarBroadcast {
|
||||
CUTLASS_DEVICE Array<Element, FragmentSize>
|
||||
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
|
||||
Array<Element, FragmentSize> frg_col;
|
||||
Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
|
||||
cute::Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < FragmentSize; ++i) {
|
||||
@@ -422,16 +422,16 @@ struct Sm90ColOrScalarBroadcast {
|
||||
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
|
||||
|
||||
auto [M, N, K, L] = args.problem_shape_mnkl;
|
||||
Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol);
|
||||
Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
cute::Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol);
|
||||
cute::Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||
Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
cute::Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
|
||||
// Generate an identity tensor matching the shape of the global tensor and
|
||||
// partition the same way, this will be used to generate the predicate
|
||||
// tensor for loading
|
||||
Tensor cCol = make_identity_tensor(mCol.shape());
|
||||
Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
cute::Tensor cCol = make_identity_tensor(mCol.shape());
|
||||
cute::Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||
|
||||
return ConsumerStoreCallbacks(
|
||||
|
||||
@@ -3,6 +3,14 @@
|
||||
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
|
||||
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp"
|
||||
|
||||
// This header is shared by both _C (unstable ABI) and _C_stable_libtorch
|
||||
// (stable ABI) targets. When compiled under the stable ABI target,
|
||||
// TORCH_TARGET_VERSION is defined and Tensor is unavailable, so we
|
||||
// use torch::stable::Tensor instead.
|
||||
#ifdef TORCH_TARGET_VERSION
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#endif
|
||||
|
||||
/*
|
||||
This file defines custom epilogues for fusing channel scales, token scales,
|
||||
bias, and activation zero-points onto a GEMM operation using the
|
||||
@@ -15,6 +23,12 @@
|
||||
|
||||
namespace vllm::c3x {
|
||||
|
||||
#ifdef TORCH_TARGET_VERSION
|
||||
using TensorType = torch::stable::Tensor;
|
||||
#else
|
||||
using TensorType = torch::Tensor;
|
||||
#endif
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <typename T>
|
||||
@@ -84,7 +98,7 @@ struct ScaledEpilogueBase {
|
||||
// from a tensor. It can handle both row and column, as well as row/column or
|
||||
// scalar cases.
|
||||
template <typename Descriptor, typename T>
|
||||
static auto args_from_tensor(torch::Tensor const& tensor) {
|
||||
static auto args_from_tensor(TensorType const& tensor) {
|
||||
using Arguments = typename Descriptor::Arguments;
|
||||
auto* data_ptr = static_cast<T*>(tensor.data_ptr());
|
||||
if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
|
||||
@@ -100,7 +114,7 @@ struct ScaledEpilogueBase {
|
||||
// This overload handles the case where there might not be a tensor, in which
|
||||
// case a nullptr is passed and a constant (0) is used.
|
||||
template <typename Descriptor, typename T>
|
||||
static auto args_from_tensor(std::optional<torch::Tensor> const& tensor) {
|
||||
static auto args_from_tensor(std::optional<TensorType> const& tensor) {
|
||||
using Arguments = typename Descriptor::Arguments;
|
||||
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
|
||||
static_assert(std::is_same_v<Descriptor, ColLoad<T, true>> ||
|
||||
@@ -158,8 +172,8 @@ struct ScaledEpilogue
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
static ArgumentType prepare_args(TensorType const& a_scales,
|
||||
TensorType const& b_scales) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
|
||||
@@ -203,9 +217,9 @@ struct ScaledEpilogueBias
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
|
||||
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& bias) {
|
||||
static ArgumentType prepare_args(TensorType const& a_scales,
|
||||
TensorType const& b_scales,
|
||||
TensorType const& bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
@@ -246,9 +260,9 @@ struct ScaledEpilogueColumnBias
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
|
||||
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& bias) {
|
||||
static ArgumentType prepare_args(TensorType const& a_scales,
|
||||
TensorType const& b_scales,
|
||||
TensorType const& bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
@@ -304,10 +318,10 @@ struct ScaledEpilogueBiasAzp
|
||||
EVTComputeScaleB, Bias>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
static ArgumentType prepare_args(TensorType const& a_scales,
|
||||
TensorType const& b_scales,
|
||||
TensorType const& azp_adj,
|
||||
std::optional<TensorType> const& bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
@@ -380,11 +394,11 @@ struct ScaledEpilogueBiasAzpToken
|
||||
EVTComputeScaleB, Bias>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
torch::Tensor const& azp,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
static ArgumentType prepare_args(TensorType const& a_scales,
|
||||
TensorType const& b_scales,
|
||||
TensorType const& azp_adj,
|
||||
TensorType const& azp,
|
||||
std::optional<TensorType> const& bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
|
||||
@@ -1,6 +1,21 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/all.h>
|
||||
// This header is shared between _C (unstable ABI, used by machete) and
|
||||
// _C_stable_libtorch (stable ABI, used by W4A8/sparse). TORCH_TARGET_VERSION
|
||||
// is defined only for the stable target, so we switch includes and types
|
||||
// accordingly. TorchTensor (not Tensor) avoids ambiguity with cute::Tensor.
|
||||
#ifdef TORCH_TARGET_VERSION
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include <torch/headeronly/util/BFloat16.h>
|
||||
#include <torch/headeronly/util/Half.h>
|
||||
#include <torch/headeronly/util/shim_utils.h> // for STD_TORCH_CHECK
|
||||
using TorchTensor = torch::stable::Tensor;
|
||||
#define TORCH_UTILS_CHECK STD_TORCH_CHECK
|
||||
#else
|
||||
#include <torch/all.h>
|
||||
using TorchTensor = torch::Tensor;
|
||||
#define TORCH_UTILS_CHECK TORCH_CHECK
|
||||
#endif
|
||||
|
||||
#include "cute/layout.hpp"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
@@ -55,35 +70,35 @@ CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f) {
|
||||
// If `tensor.dim() < rank(Stride{})`, the shape is padded with 1s and the extra
|
||||
// strides are set to be 0 or 1.
|
||||
template <typename Stride>
|
||||
static inline auto make_cute_layout(torch::Tensor const& tensor,
|
||||
static inline auto make_cute_layout(TorchTensor const& tensor,
|
||||
std::string_view name = "tensor") {
|
||||
TORCH_CHECK(tensor.dim() <= rank(Stride{}));
|
||||
auto stride = cute::transform_with_idx(
|
||||
Stride{}, [&](auto const& stride_ele, auto const& idx) {
|
||||
using StrideEle = std::decay_t<decltype(stride_ele)>;
|
||||
TORCH_UTILS_CHECK(tensor.dim() <= rank(Stride{}));
|
||||
auto stride = cute::transform_with_idx(Stride{}, [&](auto const& stride_ele,
|
||||
auto const& idx) {
|
||||
using StrideEle = std::decay_t<decltype(stride_ele)>;
|
||||
|
||||
if (idx < tensor.dim()) {
|
||||
if constexpr (cute::is_static_v<StrideEle>) {
|
||||
TORCH_CHECK(StrideEle::value == tensor.stride(idx), "Expected ",
|
||||
name, ".stride(", idx, ") to be ", StrideEle::value);
|
||||
return StrideEle{};
|
||||
} else {
|
||||
if (tensor.size(idx) == 1) {
|
||||
// use 0 stride for dim with size 1, this is easier for
|
||||
// cute/cutlass to optimize (helps the TMA code flatten dims)
|
||||
return StrideEle{0};
|
||||
} else {
|
||||
return tensor.stride(idx);
|
||||
}
|
||||
}
|
||||
if (idx < tensor.dim()) {
|
||||
if constexpr (cute::is_static_v<StrideEle>) {
|
||||
TORCH_UTILS_CHECK(StrideEle::value == tensor.stride(idx), "Expected ",
|
||||
name, ".stride(", idx, ") to be ", StrideEle::value);
|
||||
return StrideEle{};
|
||||
} else {
|
||||
if (tensor.size(idx) == 1) {
|
||||
// use 0 stride for dim with size 1, this is easier for
|
||||
// cute/cutlass to optimize (helps the TMA code flatten dims)
|
||||
return StrideEle{0};
|
||||
} else {
|
||||
// Extra strides are assumed to be 0 or 1
|
||||
if constexpr (cute::is_static_v<StrideEle>) {
|
||||
static_assert(StrideEle::value == 0 || StrideEle::value == 1);
|
||||
}
|
||||
return StrideEle{};
|
||||
return tensor.stride(idx);
|
||||
}
|
||||
});
|
||||
}
|
||||
} else {
|
||||
// Extra strides are assumed to be 0 or 1
|
||||
if constexpr (cute::is_static_v<StrideEle>) {
|
||||
static_assert(StrideEle::value == 0 || StrideEle::value == 1);
|
||||
}
|
||||
return StrideEle{};
|
||||
}
|
||||
});
|
||||
|
||||
auto shape = cute::make_shape_from_idx<rank(Stride{})>([&](auto const& idx) {
|
||||
if (idx < tensor.dim())
|
||||
@@ -97,7 +112,7 @@ static inline auto make_cute_layout(torch::Tensor const& tensor,
|
||||
|
||||
template <typename Stride>
|
||||
static inline auto maybe_make_cute_layout(
|
||||
std::optional<torch::Tensor> const& tensor,
|
||||
std::optional<TorchTensor> const& tensor,
|
||||
std::string_view name = "tensor") {
|
||||
using Layout = decltype(make_cute_layout<Stride>(*tensor));
|
||||
|
||||
@@ -121,12 +136,12 @@ template <typename T>
|
||||
using equivalent_cutlass_type_t = typename equivalent_cutlass_type<T>::type;
|
||||
|
||||
template <>
|
||||
struct equivalent_cutlass_type<c10::Half> {
|
||||
struct equivalent_cutlass_type<torch::headeronly::Half> {
|
||||
using type = cutlass::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct equivalent_cutlass_type<c10::BFloat16> {
|
||||
struct equivalent_cutlass_type<torch::headeronly::BFloat16> {
|
||||
using type = cutlass::bfloat16_t;
|
||||
};
|
||||
|
||||
@@ -134,8 +149,8 @@ struct equivalent_cutlass_type<c10::BFloat16> {
|
||||
// equivalent_scalar_t (basically inverse of equivalent_cutlass_type)
|
||||
//
|
||||
|
||||
// Return a `c10::CppTypeToScalarType<T>` compatible type, i.e. get the C++ from
|
||||
// c10 that is equivalent to T, e.g.: `cutlass::half_t -> c10::Half`
|
||||
// Return a `torch::headeronly::CppTypeToScalarType<T>` compatible type, i.e.
|
||||
// get the C++ type equivalent to T, e.g.: `cutlass::half_t -> Half`
|
||||
template <typename T>
|
||||
struct equivalent_scalar_type {
|
||||
using type = T;
|
||||
@@ -146,15 +161,15 @@ using equivalent_scalar_type_t = typename equivalent_scalar_type<T>::type;
|
||||
|
||||
template <>
|
||||
struct equivalent_scalar_type<cutlass::half_t> {
|
||||
using type = c10::Half;
|
||||
using type = torch::headeronly::Half;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct equivalent_scalar_type<cutlass::bfloat16_t> {
|
||||
using type = c10::BFloat16;
|
||||
using type = torch::headeronly::BFloat16;
|
||||
};
|
||||
|
||||
// get equivalent c10::ScalarType tag from compile time type
|
||||
// get equivalent torch::headeronly::ScalarType tag from compile time type
|
||||
template <typename T>
|
||||
static inline constexpr c10::ScalarType equivalent_scalar_type_v =
|
||||
c10::CppTypeToScalarType<equivalent_scalar_type_t<T>>::value;
|
||||
static inline constexpr torch::headeronly::ScalarType equivalent_scalar_type_v =
|
||||
torch::headeronly::CppTypeToScalarType<equivalent_scalar_type_t<T>>::value;
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
|
||||
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp"
|
||||
|
||||
/*
|
||||
@@ -52,7 +54,7 @@ struct ScaledEpilogueBase {
|
||||
// from a tensor. It can handle both row and column, as well as row/column or
|
||||
// scalar cases.
|
||||
template <typename Descriptor, typename T>
|
||||
static auto args_from_tensor(torch::Tensor const& tensor) {
|
||||
static auto args_from_tensor(torch::stable::Tensor const& tensor) {
|
||||
using Arguments = typename Descriptor::Arguments;
|
||||
auto* data_ptr = static_cast<T*>(tensor.data_ptr());
|
||||
if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
|
||||
@@ -68,7 +70,8 @@ struct ScaledEpilogueBase {
|
||||
// This overload handles the case where there might not be a tensor, in which
|
||||
// case a nullptr is passed and a constant (0) is used.
|
||||
template <typename Descriptor, typename T>
|
||||
static auto args_from_tensor(std::optional<torch::Tensor> const& tensor) {
|
||||
static auto args_from_tensor(
|
||||
std::optional<torch::stable::Tensor> const& tensor) {
|
||||
static_assert(std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
|
||||
using Arguments = typename Descriptor::Arguments;
|
||||
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
|
||||
@@ -117,8 +120,8 @@ struct ScaledEpilogue
|
||||
cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
static ArgumentType prepare_args(torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
|
||||
@@ -160,9 +163,9 @@ struct ScaledEpilogueBias
|
||||
using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
|
||||
EVTCompute0, Bias>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& bias) {
|
||||
static ArgumentType prepare_args(torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
torch::stable::Tensor const& bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
@@ -220,10 +223,11 @@ struct ScaledEpilogueBiasAzp
|
||||
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
static ArgumentType prepare_args(
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
torch::stable::Tensor const& azp_adj,
|
||||
std::optional<torch::stable::Tensor> const& bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
@@ -298,11 +302,11 @@ struct ScaledEpilogueBiasAzpToken
|
||||
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
torch::Tensor const& azp,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
static ArgumentType prepare_args(
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
torch::stable::Tensor const& azp_adj, torch::stable::Tensor const& azp,
|
||||
std::optional<torch::stable::Tensor> const& bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
@@ -49,6 +49,15 @@
|
||||
THO_DISPATCH_SWITCH(TYPE, NAME, \
|
||||
VLLM_STABLE_DISPATCH_CASE_FP8_TYPES(__VA_ARGS__))
|
||||
|
||||
// Half types dispatch (Half + BFloat16)
|
||||
#define VLLM_STABLE_DISPATCH_CASE_HALF_TYPES(...) \
|
||||
THO_DISPATCH_CASE(torch::headeronly::ScalarType::Half, __VA_ARGS__) \
|
||||
THO_DISPATCH_CASE(torch::headeronly::ScalarType::BFloat16, __VA_ARGS__)
|
||||
|
||||
#define VLLM_STABLE_DISPATCH_HALF_TYPES(TYPE, NAME, ...) \
|
||||
THO_DISPATCH_SWITCH(TYPE, NAME, \
|
||||
VLLM_STABLE_DISPATCH_CASE_HALF_TYPES(__VA_ARGS__))
|
||||
|
||||
// Boolean dispatch
|
||||
#define VLLM_STABLE_DISPATCH_BOOL(expr, const_expr, ...) \
|
||||
if (expr) { \
|
||||
|
||||
@@ -27,4 +27,111 @@ void per_token_group_quant_int8(const torch::stable::Tensor& input,
|
||||
torch::stable::Tensor& output_s,
|
||||
int64_t group_size, double eps, double int8_min,
double int8_max);

bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability);
bool cutlass_group_gemm_supported(int64_t cuda_device_capability);

void cutlass_scaled_mm(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);

void cutlass_moe_mm(torch::stable::Tensor& out_tensors,
torch::stable::Tensor const& a_tensors,
torch::stable::Tensor const& b_tensors,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& expert_offsets,
torch::stable::Tensor const& problem_sizes,
torch::stable::Tensor const& a_strides,
torch::stable::Tensor const& b_strides,
torch::stable::Tensor const& c_strides, bool per_act_token,
bool per_out_ch);

void cutlass_scaled_mm_azp(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias);

void get_cutlass_moe_mm_data(
const torch::stable::Tensor& topk_ids,
torch::stable::Tensor& expert_offsets,
torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2,
torch::stable::Tensor& input_permutation,
torch::stable::Tensor& output_permutation, const int64_t num_experts,
const int64_t n, const int64_t k,
const std::optional<torch::stable::Tensor>& blockscale_offsets,
const bool is_gated);

void get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
const torch::stable::Tensor& expert_first_token_offset,
torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2, const int64_t n, const int64_t k,
const bool swap_ab);

void get_cutlass_batched_moe_mm_data(
torch::stable::Tensor& expert_offsets,
torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2,
const torch::stable::Tensor& expert_num_tokens,
const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
const int64_t k);

// FP4/NVFP4 ops
bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability);

void cutlass_scaled_fp4_mm(torch::stable::Tensor& D,
torch::stable::Tensor const& A,
torch::stable::Tensor const& B,
torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha);

void cutlass_fp4_group_mm(torch::stable::Tensor& output,
const torch::stable::Tensor& a,
const torch::stable::Tensor& b,
const torch::stable::Tensor& a_blockscale,
const torch::stable::Tensor& b_blockscales,
const torch::stable::Tensor& alphas,
const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& sf_offsets);

std::tuple<torch::stable::Tensor, torch::stable::Tensor> scaled_fp4_quant_func(
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_scale, bool is_sf_swizzled_layout);

void scaled_fp4_quant_out(torch::stable::Tensor const& input,
torch::stable::Tensor const& input_scale,
bool is_sf_swizzled_layout,
torch::stable::Tensor& output,
torch::stable::Tensor& output_scale);

void scaled_fp4_experts_quant(
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_global_scale,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts);

void silu_and_mul_scaled_fp4_experts_quant(
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_global_scale,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts);

void silu_and_mul_nvfp4_quant(torch::stable::Tensor& out,
torch::stable::Tensor& output_block_scale,
torch::stable::Tensor& input,
torch::stable::Tensor& input_global_scale);

#endif

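The declarations above follow one mechanical pattern: every `torch::Tensor`/`at::Tensor` parameter becomes `torch::stable::Tensor`, and dtype arguments move from `at::ScalarType` to `torch::headeronly::ScalarType`. The snippet below is an illustrative sketch only, not part of this diff; the function name `illustrative_scaled_copy` and its shape logic are hypothetical, while the calls (`STD_TORCH_CHECK`, `torch::stable::empty`, `torch::headeronly::ScalarType`) are the ones this patch series actually uses.

#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"

// Hypothetical stable-ABI op: validate the input, then allocate an output
// shaped like `a`, mirroring the torch::stable::empty calls used in this diff.
torch::stable::Tensor illustrative_scaled_copy(torch::stable::Tensor const& a) {
  STD_TORCH_CHECK(a.dim() == 2, "expected a 2-D input");
  STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float,
                  "expected float32 input");
  return torch::stable::empty({a.size(0), a.size(1)}, a.scalar_type(),
                              std::nullopt, a.device());
}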
@@ -2,10 +2,9 @@
#pragma once

#include <cuda.h>
#include <torch/all.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"

#include "core/scalar_type.hpp"
#include "cutlass/bfloat16.h"
#include "cutlass/float8.h"

@@ -41,7 +40,7 @@ __global__ void get_group_gemm_starts(
}

#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE) \
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
else if (out_tensors.scalar_type() == TENSOR_C_TYPE) { \
get_group_gemm_starts<cutlass::float_e4m3_t, int32_t, C_TYPE, float, \
cutlass::Array<cutlass::float_e4m3_t, 8>> \
<<<1, num_experts, 0, stream>>>( \
@@ -66,23 +65,34 @@ __global__ void get_group_gemm_starts(
namespace {

void run_get_group_gemm_starts(
torch::Tensor const& expert_offsets, torch::Tensor& a_ptrs,
torch::Tensor& b_ptrs, torch::Tensor& out_ptrs,
torch::Tensor& a_scales_ptrs, torch::Tensor& b_scales_ptrs,
torch::Tensor& b_group_scales_ptrs, torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor& out_tensors,
torch::Tensor const& a_scales, torch::Tensor const& b_scales,
torch::Tensor const& b_group_scales, const int64_t b_group_size) {
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b_tensors.dtype() == torch::kInt32); // int4 8x packed into int32
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_group_scales.dtype() ==
torch::kFloat8_e4m3fn); // the underlying torch type is e4m3
TORCH_CHECK(out_tensors.dtype() ==
torch::kBFloat16); // only support bf16 for now
torch::stable::Tensor const& expert_offsets, torch::stable::Tensor& a_ptrs,
torch::stable::Tensor& b_ptrs, torch::stable::Tensor& out_ptrs,
torch::stable::Tensor& a_scales_ptrs, torch::stable::Tensor& b_scales_ptrs,
torch::stable::Tensor& b_group_scales_ptrs,
torch::stable::Tensor const& a_tensors,
torch::stable::Tensor const& b_tensors, torch::stable::Tensor& out_tensors,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& b_group_scales, const int64_t b_group_size) {
STD_TORCH_CHECK(a_tensors.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
STD_TORCH_CHECK(
b_tensors.scalar_type() ==
torch::headeronly::ScalarType::Int); // int4 8x packed into int32
STD_TORCH_CHECK(a_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(
b_group_scales.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn); // the underlying torch
// type is e4m3
STD_TORCH_CHECK(
out_tensors.scalar_type() ==
torch::headeronly::ScalarType::BFloat16); // only support bf16 for now
// expect int64_t to avoid overflow during offset calculations
TORCH_CHECK(expert_offsets.dtype() == torch::kInt64);
STD_TORCH_CHECK(expert_offsets.scalar_type() ==
torch::headeronly::ScalarType::Long);

int num_experts = static_cast<int>(expert_offsets.size(0));
// logical k, n
@@ -90,15 +100,16 @@ void run_get_group_gemm_starts(
int64_t k = a_tensors.size(1);
int64_t scale_k = cutlass::ceil_div(k, b_group_size);

auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
auto stream = get_current_cuda_stream(a_tensors.get_device_index());

if (false) {
}
__CALL_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t)
__CALL_GET_STARTS_KERNEL(torch::kFloat16, half)
__CALL_GET_STARTS_KERNEL(torch::headeronly::ScalarType::BFloat16,
cutlass::bfloat16_t)
__CALL_GET_STARTS_KERNEL(torch::headeronly::ScalarType::Half, half)
else {
TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
STD_TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
}
}

} // namespace
} // namespace
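The hunk above captures the core dtype-check migration repeated throughout this change: checks move from `TORCH_CHECK` on `Tensor::dtype()` with `torch::k*` constants to `STD_TORCH_CHECK` on `scalar_type()` with `torch::headeronly::ScalarType` enumerators. A minimal before/after sketch, using a hypothetical tensor `t`:

// Before: ATen-bound check (form removed by this patch)
// TORCH_CHECK(t.dtype() == torch::kFloat32);

// After: header-only / stable-ABI check (form added by this patch)
STD_TORCH_CHECK(t.scalar_type() == torch::headeronly::ScalarType::Float);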
@@ -14,13 +14,12 @@
#include "cutlass/util/mixed_dtype_utils.hpp"

// vllm includes
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "cutlass_extensions/torch_utils.hpp"
#include "cutlass_extensions/common.hpp"

#include "core/registration.h"
#include "get_group_starts.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
#include "w4a8_utils.cuh"
@@ -168,31 +167,40 @@ struct W4A8GroupedGemmKernel {
static_assert(sizeof(LayoutB_Reordered) % sizeof(int32_t) == 0,
"LayoutB_Reordered size must be divisible by 4 bytes");

static void grouped_mm(
torch::Tensor& out_tensors, const torch::Tensor& a_tensors,
const torch::Tensor& b_tensors, const torch::Tensor& a_scales,
const torch::Tensor& b_scales, const torch::Tensor& b_group_scales,
const int64_t b_group_size, const torch::Tensor& expert_offsets,
const torch::Tensor& problem_sizes_torch, const torch::Tensor& a_strides,
const torch::Tensor& b_strides, const torch::Tensor& c_strides,
const torch::Tensor& group_scale_strides) {
static void grouped_mm(torch::stable::Tensor& out_tensors,
const torch::stable::Tensor& a_tensors,
const torch::stable::Tensor& b_tensors,
const torch::stable::Tensor& a_scales,
const torch::stable::Tensor& b_scales,
const torch::stable::Tensor& b_group_scales,
const int64_t b_group_size,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& problem_sizes_torch,
const torch::stable::Tensor& a_strides,
const torch::stable::Tensor& b_strides,
const torch::stable::Tensor& c_strides,
const torch::stable::Tensor& group_scale_strides) {
auto device = a_tensors.device();
auto device_id = device.index();
const at::cuda::OptionalCUDAGuard device_guard(device);
auto stream = at::cuda::getCurrentCUDAStream(device_id);
const torch::stable::accelerator::DeviceGuard device_guard(device_id);
auto stream = get_current_cuda_stream(device_id);

int num_experts = static_cast<int>(expert_offsets.size(0));
int n = static_cast<int>(b_tensors.size(1));
int k = static_cast<int>(b_tensors.size(2)) * PackFactor;

auto options_int =
torch::TensorOptions().dtype(torch::kInt64).device(device);
torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_group_scales_ptrs = torch::empty(num_experts, options_int);
torch::stable::Tensor a_ptrs = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor b_ptrs = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor out_ptrs = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor a_scales_ptrs = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor b_scales_ptrs = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor b_group_scales_ptrs = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);

// get the correct offsets to pass to gemm
run_get_group_gemm_starts(expert_offsets, a_ptrs, b_ptrs, out_ptrs,
@@ -247,9 +255,9 @@ struct W4A8GroupedGemmKernel {

// Allocate workspace
size_t workspace_size = GemmShuffled::get_workspace_size(arguments);
torch::Tensor workspace =
torch::empty(workspace_size,
torch::TensorOptions().dtype(torch::kU8).device(device));
torch::stable::Tensor workspace = torch::stable::empty(
workspace_size, torch::headeronly::ScalarType::Byte, std::nullopt,
device);

// Run GEMM
GemmShuffled gemm;
@@ -294,14 +302,20 @@ using Kernel_256x128_2x1x1_Coop =
using Kernel_128x256_2x1x1_Coop =
W4A8GroupedGemmKernel<Shape<_128, _256>, Shape<_2, _1, _1>, Coop, CoopEpi>;

void mm_dispatch(
torch::Tensor& out_tensors, const torch::Tensor& a_tensors,
const torch::Tensor& b_tensors, const torch::Tensor& a_scales,
const torch::Tensor& b_scales, const torch::Tensor& b_group_scales,
const int64_t b_group_size, const torch::Tensor& expert_offsets,
const torch::Tensor& problem_sizes, const torch::Tensor& a_strides,
const torch::Tensor& b_strides, const torch::Tensor& c_strides,
const torch::Tensor& group_scale_strides, const std::string& schedule) {
void mm_dispatch(torch::stable::Tensor& out_tensors,
const torch::stable::Tensor& a_tensors,
const torch::stable::Tensor& b_tensors,
const torch::stable::Tensor& a_scales,
const torch::stable::Tensor& b_scales,
const torch::stable::Tensor& b_group_scales,
const int64_t b_group_size,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& a_strides,
const torch::stable::Tensor& b_strides,
const torch::stable::Tensor& c_strides,
const torch::stable::Tensor& group_scale_strides,
const std::string& schedule) {
if (schedule == "Kernel_128x16_1x1x1_Coop") {
Kernel_128x16_1x1x1_Coop::grouped_mm(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales,
@@ -358,18 +372,23 @@ void mm_dispatch(
b_group_size, expert_offsets, problem_sizes, a_strides, b_strides,
c_strides, group_scale_strides);
} else {
TORCH_CHECK(false,
"cutlass_w4a8_moe_mm: unknown schedule string: ", schedule);
STD_TORCH_CHECK(false,
"cutlass_w4a8_moe_mm: unknown schedule string: ", schedule);
}
}

void mm(torch::Tensor& out_tensors, const torch::Tensor& a_tensors,
const torch::Tensor& b_tensors, const torch::Tensor& a_scales,
const torch::Tensor& b_scales, const torch::Tensor& b_group_scales,
const int64_t b_group_size, const torch::Tensor& expert_offsets,
const torch::Tensor& problem_sizes, const torch::Tensor& a_strides,
const torch::Tensor& b_strides, const torch::Tensor& c_strides,
const torch::Tensor& group_scale_strides,
void mm(torch::stable::Tensor& out_tensors,
const torch::stable::Tensor& a_tensors,
const torch::stable::Tensor& b_tensors,
const torch::stable::Tensor& a_scales,
const torch::stable::Tensor& b_scales,
const torch::stable::Tensor& b_group_scales, const int64_t b_group_size,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& a_strides,
const torch::stable::Tensor& b_strides,
const torch::stable::Tensor& c_strides,
const torch::stable::Tensor& group_scale_strides,
std::optional<std::string> maybe_schedule) {
// user has specified a schedule
if (maybe_schedule) {
@@ -406,26 +425,27 @@ void mm(torch::Tensor& out_tensors, const torch::Tensor& a_tensors,
a_strides, b_strides, c_strides, group_scale_strides, schedule);
}

std::tuple<torch::Tensor, torch::Tensor> encode_and_reorder_int4b(
torch::Tensor const& b_tensors) {
TORCH_CHECK(b_tensors.dtype() == torch::kInt32);
TORCH_CHECK(b_tensors.dim() == 3); // (experts, n, k)
TORCH_CHECK(b_tensors.is_contiguous());
TORCH_CHECK(b_tensors.is_cuda());
std::tuple<torch::stable::Tensor, torch::stable::Tensor>
encode_and_reorder_int4b(torch::stable::Tensor const& b_tensors) {
STD_TORCH_CHECK(b_tensors.scalar_type() ==
torch::headeronly::ScalarType::Int);
STD_TORCH_CHECK(b_tensors.dim() == 3); // (experts, n, k)
STD_TORCH_CHECK(b_tensors.is_contiguous());
STD_TORCH_CHECK(b_tensors.is_cuda());

int n = static_cast<int>(b_tensors.size(1));
int k = static_cast<int>(b_tensors.size(2)) * PackFactor; // logical k

// CUTLASS reorder_tensor requires k % 256 == 0 and n % 16 == 0.
// These misalignments cause silent OOB unless run under Compute Sanitizer.
TORCH_CHECK(k % 256 == 0, "logical k must be divisible by 256");
TORCH_CHECK(n % 16 == 0, "n must be divisible by 16");
STD_TORCH_CHECK(k % 256 == 0, "logical k must be divisible by 256");
STD_TORCH_CHECK(n % 16 == 0, "n must be divisible by 16");

// we will store the layout to an int32 tensor;
// this is the number of elements we need per layout
constexpr size_t layout_width = sizeof(LayoutB_Reordered) / sizeof(int32_t);

torch::Tensor b_tensors_packed = torch::empty_like(b_tensors);
torch::stable::Tensor b_tensors_packed = torch::stable::empty_like(b_tensors);
int num_experts = static_cast<int>(b_tensors.size(0));

auto b_ptr = static_cast<QuantType const*>(b_tensors.const_data_ptr());
@@ -435,7 +455,7 @@ std::tuple<torch::Tensor, torch::Tensor> encode_and_reorder_int4b(
size_t num_int4_elems = 1ull * num_experts * n * k;
bool ok = vllm::cutlass_w4a8_utils::unified_encode_int4b(b_ptr, b_packed_ptr,
num_int4_elems);
TORCH_CHECK(ok, "unified_encode_int4b failed");
STD_TORCH_CHECK(ok, "unified_encode_int4b failed");

// construct the layout once; assumes each expert has the same layout
using LayoutType = LayoutB_Reordered;
@@ -456,28 +476,28 @@ std::tuple<torch::Tensor, torch::Tensor> encode_and_reorder_int4b(
}

// save the packed layout to torch tensor so we can re-use it
auto cpu_opts =
torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU);
torch::Tensor layout_cpu =
torch::empty({num_experts, layout_width}, cpu_opts);
torch::stable::Tensor layout_cpu = torch::stable::empty(
{num_experts, layout_width}, torch::headeronly::ScalarType::Int,
std::nullopt, torch::stable::Device(torch::stable::DeviceType::CPU));

int32_t* layout_data = layout_cpu.data_ptr<int32_t>();
int32_t* layout_data = layout_cpu.mutable_data_ptr<int32_t>();
for (int i = 0; i < num_experts; ++i) {
std::memcpy(layout_data + i * layout_width, // dst (int32*)
&layout_B_reordered, // src (LayoutType*)
sizeof(LayoutType)); // number of bytes
}

torch::Tensor packed_layout =
layout_cpu.to(b_tensors.device(), /*non_blocking=*/false);
torch::stable::Tensor packed_layout =
torch::stable::to(layout_cpu, b_tensors.device(),
/*non_blocking=*/false);

return {b_tensors_packed, packed_layout};
}

TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("cutlass_w4a8_moe_mm", &mm);
m.impl("cutlass_encode_and_reorder_int4b_grouped", &encode_and_reorder_int4b);
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) {
m.impl("cutlass_w4a8_moe_mm", TORCH_BOX(&mm));
m.impl("cutlass_encode_and_reorder_int4b_grouped",
TORCH_BOX(&encode_and_reorder_int4b));
}

} // namespace vllm::cutlass_w4a8_moe
/////////////////////////////////////////////////////////////////////////////////////////////////
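The registration change at the end of this file is the same in every file this patch touches: `TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m)` with plain function pointers becomes `STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m)` with each kernel wrapped in `TORCH_BOX`. A minimal sketch of the pattern, with a hypothetical op name `my_op`:

// Stable-ABI registration pattern used by this patch (op name is illustrative).
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) {
  m.impl("my_op", TORCH_BOX(&my_op));
}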
@@ -3,14 +3,12 @@
// https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu
//

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "cutlass_extensions/torch_utils.hpp"
#include "w4a8_utils.cuh"

#include "core/registration.h"

#include "cutlass/cutlass.h"
#include <limits>

@@ -161,31 +159,31 @@ struct W4A8GemmKernel {
using StrideD = typename GemmKernelShuffled::StrideD;
using StrideS = typename CollectiveMainloopShuffled::StrideScale;

static torch::Tensor mm(torch::Tensor const& A,
torch::Tensor const& B, // already packed
torch::Tensor const& group_scales, // already packed
int64_t group_size,
torch::Tensor const& channel_scales,
torch::Tensor const& token_scales,
std::optional<at::ScalarType> const& maybe_out_type) {
static torch::stable::Tensor mm(
torch::stable::Tensor const& A,
torch::stable::Tensor const& B, // already packed
torch::stable::Tensor const& group_scales, // already packed
int64_t group_size, torch::stable::Tensor const& channel_scales,
torch::stable::Tensor const& token_scales,
std::optional<torch::headeronly::ScalarType> const& maybe_out_type) {
// TODO: param validation
int m = A.size(0);
int k = A.size(1);
int n = B.size(1);

// safely cast group_size to int
TORCH_CHECK(group_size > 0 && group_size <= std::numeric_limits<int>::max(),
"group_size out of supported range for int: ", group_size);
STD_TORCH_CHECK(
group_size > 0 && group_size <= std::numeric_limits<int>::max(),
"group_size out of supported range for int: ", group_size);
int const group_size_int = static_cast<int>(group_size);

// Allocate output
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
const torch::stable::accelerator::DeviceGuard device_guard(
A.get_device_index());
auto device = A.device();
auto stream = at::cuda::getCurrentCUDAStream(device.index());
torch::Tensor D =
torch::empty({m, n}, torch::TensorOptions()
.dtype(equivalent_scalar_type_v<ElementD>)
.device(device));
auto stream = get_current_cuda_stream(device.index());
torch::stable::Tensor D = torch::stable::empty(
{m, n}, equivalent_scalar_type_v<ElementD>, std::nullopt, device);
// prepare arg pointers
auto A_ptr = static_cast<MmaType const*>(A.const_data_ptr());
auto B_ptr = static_cast<QuantType const*>(B.const_data_ptr());
@@ -237,9 +235,9 @@ struct W4A8GemmKernel {

// Workspace
size_t workspace_size = GemmShuffled::get_workspace_size(arguments);
torch::Tensor workspace =
torch::empty(workspace_size,
torch::TensorOptions().dtype(torch::kU8).device(device));
torch::stable::Tensor workspace = torch::stable::empty(
workspace_size, torch::headeronly::ScalarType::Byte, std::nullopt,
device);

// Run GEMM
GemmShuffled gemm;
@@ -269,14 +267,14 @@ using Kernel_128x64_1x1x1 = W4A8GemmKernel<Shape<_128, _64>, Shape<_1, _1, _1>>;
using Kernel_128x32_1x1x1 = W4A8GemmKernel<Shape<_128, _32>, Shape<_1, _1, _1>>;
using Kernel_128x16_1x1x1 = W4A8GemmKernel<Shape<_128, _16>, Shape<_1, _1, _1>>;

torch::Tensor mm_dispatch(torch::Tensor const& A,
torch::Tensor const& B, // already packed
torch::Tensor const& group_scales, // already packed
int64_t group_size,
torch::Tensor const& channel_scales,
torch::Tensor const& token_scales,
std::optional<at::ScalarType> const& maybe_out_type,
const std::string& schedule) {
torch::stable::Tensor mm_dispatch(
torch::stable::Tensor const& A,
torch::stable::Tensor const& B, // already packed
torch::stable::Tensor const& group_scales, // already packed
int64_t group_size, torch::stable::Tensor const& channel_scales,
torch::stable::Tensor const& token_scales,
std::optional<torch::headeronly::ScalarType> const& maybe_out_type,
const std::string& schedule) {
if (schedule == "256x128_1x1x1") {
return Kernel_256x128_1x1x1::mm(A, B, group_scales, group_size,
channel_scales, token_scales,
@@ -318,17 +316,18 @@ torch::Tensor mm_dispatch(torch::Tensor const& A,
channel_scales, token_scales,
maybe_out_type);
}
TORCH_CHECK(false, "Unknown W4A8 schedule: ", schedule);
STD_TORCH_CHECK(false, "Unknown W4A8 schedule: ", schedule);
return {};
}

torch::Tensor mm(torch::Tensor const& A,
torch::Tensor const& B, // already packed
torch::Tensor const& group_scales, // already packed
int64_t group_size, torch::Tensor const& channel_scales,
torch::Tensor const& token_scales,
std::optional<at::ScalarType> const& maybe_out_type,
std::optional<std::string> maybe_schedule) {
torch::stable::Tensor mm(
torch::stable::Tensor const& A,
torch::stable::Tensor const& B, // already packed
torch::stable::Tensor const& group_scales, // already packed
int64_t group_size, torch::stable::Tensor const& channel_scales,
torch::stable::Tensor const& token_scales,
std::optional<torch::headeronly::ScalarType> const& maybe_out_type,
std::optional<std::string> maybe_schedule) {
// requested a specific schedule
if (maybe_schedule) {
return mm_dispatch(A, B, group_scales, group_size, channel_scales,
@@ -378,14 +377,15 @@ torch::Tensor mm(torch::Tensor const& A,
// ----------------------------------------------------------------------------
// Pre-processing utils
// ----------------------------------------------------------------------------
torch::Tensor pack_scale_fp8(torch::Tensor const& scales) {
TORCH_CHECK(scales.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(scales.is_contiguous());
TORCH_CHECK(scales.is_cuda());
torch::stable::Tensor pack_scale_fp8(torch::stable::Tensor const& scales) {
STD_TORCH_CHECK(scales.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
STD_TORCH_CHECK(scales.is_contiguous());
STD_TORCH_CHECK(scales.is_cuda());

auto packed_scales = torch::empty(
{scales.numel() * ScalePackSize},
torch::TensorOptions().dtype(scales.dtype()).device(scales.device()));
auto packed_scales =
torch::stable::empty({scales.numel() * ScalePackSize},
scales.scalar_type(), std::nullopt, scales.device());
auto scales_ptr = static_cast<MmaType const*>(scales.const_data_ptr());
auto packed_scales_ptr =
static_cast<cutlass::Array<ElementScale, ScalePackSize>*>(
@@ -396,15 +396,16 @@ torch::Tensor pack_scale_fp8(torch::Tensor const& scales) {
return packed_scales;
}

torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) {
TORCH_CHECK(B.dtype() == torch::kInt32);
TORCH_CHECK(B.dim() == 2);
torch::stable::Tensor encode_and_reorder_int4b(torch::stable::Tensor const& B) {
STD_TORCH_CHECK(B.scalar_type() == torch::headeronly::ScalarType::Int);
STD_TORCH_CHECK(B.dim() == 2);

torch::Tensor B_packed = torch::empty_like(B);
torch::stable::Tensor B_packed = torch::stable::empty_like(B);

int k = B.size(0) * PackFactor; // logical k
int n = B.size(1);
TORCH_CHECK((n * k) % 32 == 0, "need multiples of 32 int4s for 16B chunks");
STD_TORCH_CHECK((n * k) % 32 == 0,
"need multiples of 32 int4s for 16B chunks");

auto B_ptr = static_cast<QuantType const*>(B.const_data_ptr());
auto B_packed_ptr = static_cast<QuantType*>(B_packed.data_ptr());
@@ -415,16 +416,17 @@ torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) {

bool ok = vllm::cutlass_w4a8_utils::unified_encode_int4b(B_ptr, B_packed_ptr,
n * k);
TORCH_CHECK(ok, "unified_encode_int4b failed");
STD_TORCH_CHECK(ok, "unified_encode_int4b failed");
cutlass::reorder_tensor(B_packed_ptr, layout_B, layout_B_reordered);

return B_packed;
}

TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("cutlass_w4a8_mm", &mm);
m.impl("cutlass_pack_scale_fp8", &pack_scale_fp8);
m.impl("cutlass_encode_and_reorder_int4b", &encode_and_reorder_int4b);
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) {
m.impl("cutlass_w4a8_mm", TORCH_BOX(&mm));
m.impl("cutlass_pack_scale_fp8", TORCH_BOX(&pack_scale_fp8));
m.impl("cutlass_encode_and_reorder_int4b",
TORCH_BOX(&encode_and_reorder_int4b));
}

} // namespace vllm::cutlass_w4a8
} // namespace vllm::cutlass_w4a8
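Output and workspace allocation in this file follows the same substitution as the grouped kernel: `torch::empty(sizes, torch::TensorOptions().dtype(...).device(...))` becomes `torch::stable::empty(sizes, dtype, std::nullopt, device)`, with the `std::nullopt` argument carried over verbatim from the calls in this diff. A minimal sketch with hypothetical sizes `m`, `n` and an existing `device`:

// Before (removed): ATen allocation via TensorOptions
// torch::Tensor out = torch::empty(
//     {m, n}, torch::TensorOptions().dtype(torch::kBFloat16).device(device));

// After (added): stable-ABI allocation, mirroring the calls in this patch
torch::stable::Tensor out = torch::stable::empty(
    {m, n}, torch::headeronly::ScalarType::BFloat16, std::nullopt, device);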
@@ -14,16 +14,15 @@
* limitations under the License.
*/

#include <torch/all.h>
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "libtorch_stable/dispatch_utils.h"
#include "cuda_vec_utils.cuh"

#include <cuda_runtime_api.h>
#include <cuda_runtime.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <cuda_fp8.h>
#include "dispatch_utils.h"

#include "cuda_utils.h"
#include "launch_bounds_utils.h"
@@ -118,17 +117,19 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))

} // namespace vllm

void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d]
torch::Tensor& output_sf,
torch::Tensor& input, // [..., 2 * d]
torch::Tensor& input_sf) {
void silu_and_mul_nvfp4_quant_sm1xxa(
torch::stable::Tensor& output, // [..., d]
torch::stable::Tensor& output_sf,
torch::stable::Tensor& input, // [..., 2 * d]
torch::stable::Tensor& input_sf) {
int32_t m = input.size(0);
int32_t n = input.size(1) / 2;

TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");
TORCH_CHECK(input.scalar_type() == at::ScalarType::Half ||
input.scalar_type() == at::ScalarType::BFloat16,
"Unsupported input data type for quantize_to_fp4.");
STD_TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");
STD_TORCH_CHECK(
input.scalar_type() == torch::headeronly::ScalarType::Half ||
input.scalar_type() == torch::headeronly::ScalarType::BFloat16,
"Unsupported input data type for quantize_to_fp4.");

int multiProcessorCount =
get_device_attribute(cudaDevAttrMultiProcessorCount, -1);
@@ -136,8 +137,9 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d]
auto input_sf_ptr = static_cast<float const*>(input_sf.data_ptr());
auto sf_out = static_cast<int32_t*>(output_sf.data_ptr());
auto output_ptr = static_cast<int64_t*>(output.data_ptr());
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
const torch::stable::accelerator::DeviceGuard device_guard(
input.get_device_index());
auto stream = get_current_cuda_stream(input.get_device_index());
dim3 block(std::min(int(n / ELTS_PER_THREAD), 512));
int const numBlocksPerSM =
vllm_runtime_blocks_per_sm(static_cast<int>(block.x));
@@ -149,7 +151,7 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d]
int(m), std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y));
dim3 grid(grid_x, grid_y);

VLLM_DISPATCH_HALF_TYPES(
VLLM_STABLE_DISPATCH_HALF_TYPES(
input.scalar_type(), "silu_and_mul_nvfp4_quant_kernel", [&] {
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
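Device and stream handling migrates the same way in every kernel entry point of this change: `at::cuda::OptionalCUDAGuard` plus `at::cuda::getCurrentCUDAStream` are replaced by `torch::stable::accelerator::DeviceGuard` plus the project helper `get_current_cuda_stream` from "libtorch_stable/torch_utils.h", both keyed on `Tensor::get_device_index()`. A minimal sketch for a hypothetical input tensor `input`:

// Before (removed):
// const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
// auto stream = at::cuda::getCurrentCUDAStream(input.get_device());

// After (added), as in this patch:
const torch::stable::accelerator::DeviceGuard device_guard(
    input.get_device_index());
auto stream = get_current_cuda_stream(input.get_device_index());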
@@ -14,14 +14,12 @@
* limitations under the License.
*/

#include "core/registration.h"
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"

#include <torch/all.h>
#include <cutlass/arch/arch.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include "cutlass_extensions/common.hpp"

#include "cute/tensor.hpp"
@@ -122,7 +120,7 @@ __global__ void __get_group_gemm_starts(
#define __CALL_GET_STARTS_KERNEL_BLOCKSCALE(ELEMENT_AB_TYPE, SF_TYPE, \
TENSOR_C_TYPE, C_TYPE, LayoutSFA, \
LayoutSFB, ScaleConfig) \
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
else if (out_tensors.scalar_type() == TENSOR_C_TYPE) { \
__get_group_gemm_starts<ELEMENT_AB_TYPE, C_TYPE, SF_TYPE, float, \
LayoutSFA, LayoutSFB, ScaleConfig> \
<<<1, num_experts, 0, stream>>>( \
@@ -150,50 +148,64 @@ __global__ void __get_group_gemm_starts(
}

template <typename LayoutSFA, typename LayoutSFB, typename ScaleConfig>
void run_get_group_gemm_starts(
const torch::Tensor& a_starts, const torch::Tensor& b_starts,
const torch::Tensor& out_starts, const torch::Tensor& a_scales_starts,
const torch::Tensor& b_scales_starts, const torch::Tensor& alpha_starts,
const torch::Tensor& layout_sfa, const torch::Tensor& layout_sfb,
const torch::Tensor& a_strides, const torch::Tensor& b_strides,
const torch::Tensor& c_strides, int64_t a_stride_val, int64_t b_stride_val,
int64_t c_stride_val,
/*these are used for their base addresses*/
torch::Tensor const& a_tensors, torch::Tensor const& b_tensors,
torch::Tensor const& out_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& alphas,
torch::Tensor const& expert_offsets, torch::Tensor const& sf_offsets,
torch::Tensor const& problem_sizes, int M, int N, int K) {
void run_get_group_gemm_starts(const torch::stable::Tensor& a_starts,
const torch::stable::Tensor& b_starts,
const torch::stable::Tensor& out_starts,
const torch::stable::Tensor& a_scales_starts,
const torch::stable::Tensor& b_scales_starts,
const torch::stable::Tensor& alpha_starts,
const torch::stable::Tensor& layout_sfa,
const torch::stable::Tensor& layout_sfb,
const torch::stable::Tensor& a_strides,
const torch::stable::Tensor& b_strides,
const torch::stable::Tensor& c_strides,
int64_t a_stride_val, int64_t b_stride_val,
int64_t c_stride_val,
/*these are used for their base addresses*/
torch::stable::Tensor const& a_tensors,
torch::stable::Tensor const& b_tensors,
torch::stable::Tensor const& out_tensors,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& alphas,
torch::stable::Tensor const& expert_offsets,
torch::stable::Tensor const& sf_offsets,
torch::stable::Tensor const& problem_sizes,
int M, int N, int K) {
int num_experts = (int)expert_offsets.size(0);
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
auto stream = get_current_cuda_stream(a_tensors.get_device_index());

TORCH_CHECK(out_tensors.size(1) == N,
"Output tensor shape doesn't match expected shape");
TORCH_CHECK(K / 2 == b_tensors.size(2),
"b_tensors(dim = 2) and a_tensors(dim = 1) trailing"
" dimension must match");
STD_TORCH_CHECK(out_tensors.size(1) == N,
"Output tensor shape doesn't match expected shape");
STD_TORCH_CHECK(K / 2 == b_tensors.size(2),
"b_tensors(dim = 2) and a_tensors(dim = 1) trailing"
" dimension must match");
if (false) {
}
//(ELEMENT_AB_TYPE, BS_TYPE, TENSOR_C_TYPE, C_TYPE, LayoutSFA, LayoutSFB,
// ScaleConfig)
__CALL_GET_STARTS_KERNEL_BLOCKSCALE(
cutlass::float_e2m1_t, cutlass::float_ue4m3_t, torch::kBFloat16,
cutlass::bfloat16_t, LayoutSFA, LayoutSFB, ScaleConfig)
cutlass::float_e2m1_t, cutlass::float_ue4m3_t,
torch::headeronly::ScalarType::BFloat16, cutlass::bfloat16_t, LayoutSFA,
LayoutSFB, ScaleConfig)
__CALL_GET_STARTS_KERNEL_BLOCKSCALE(cutlass::float_e2m1_t,
cutlass::float_ue4m3_t, torch::kFloat16,
half, LayoutSFA, LayoutSFB, ScaleConfig)
cutlass::float_ue4m3_t,
torch::headeronly::ScalarType::Half, half,
LayoutSFA, LayoutSFB, ScaleConfig)
else {
TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
STD_TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
}
}

template <typename OutType>
void run_fp4_blockwise_scaled_group_mm_sm100(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M,
int N, int K) {
torch::stable::Tensor& output, const torch::stable::Tensor& a,
const torch::stable::Tensor& b, const torch::stable::Tensor& a_blockscale,
const torch::stable::Tensor& b_blockscales,
const torch::stable::Tensor& alphas,
const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& sf_offsets, int M, int N, int K) {
using ProblemShape =
cutlass::gemm::GroupProblemShape<Shape<int32_t, int32_t, int32_t>>;
using ElementType = cutlass::float_e2m1_t;
@@ -272,20 +284,40 @@ void run_fp4_blockwise_scaled_group_mm_sm100(

using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
int num_experts = static_cast<int>(expert_offsets.size(0));
auto options_int =
torch::TensorOptions().dtype(torch::kInt64).device(a.device());

torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int);
torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int);
torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int);
torch::Tensor a_strides1 = torch::empty(num_experts, options_int);
torch::Tensor b_strides1 = torch::empty(num_experts, options_int);
torch::Tensor c_strides1 = torch::empty(num_experts, options_int);
torch::stable::Tensor a_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor b_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor out_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor a_scales_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor b_scales_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor alpha_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor layout_sfa = torch::stable::empty(
{num_experts, 5}, torch::headeronly::ScalarType::Long, std::nullopt,
a.device());
torch::stable::Tensor layout_sfb = torch::stable::empty(
{num_experts, 5}, torch::headeronly::ScalarType::Long, std::nullopt,
a.device());
torch::stable::Tensor a_strides1 =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor b_strides1 =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor c_strides1 =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());

run_get_group_gemm_starts<LayoutSFA, LayoutSFB, ScaleConfig>(
a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, alpha_ptrs,
@@ -308,7 +340,7 @@ void run_fp4_blockwise_scaled_group_mm_sm100(
typename ProblemShape::UnderlyingProblemShape>::RasterOrderOptions;
typename Gemm::GemmKernel::TileSchedulerArguments scheduler;
scheduler.raster_order = RasterOrderOptions::AlongM;
hw_info.device_id = a.get_device();
hw_info.device_id = a.get_device_index();
static std::unordered_map<int, int> cached_sm_counts;
if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) {
cached_sm_counts[hw_info.device_id] =
@@ -350,32 +382,35 @@ void run_fp4_blockwise_scaled_group_mm_sm100(
scheduler};

size_t workspace_size = Gemm::get_workspace_size(args);
auto const workspace_options =
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
auto workspace = torch::empty(workspace_size, workspace_options);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device());
auto workspace =
torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
std::nullopt, a.device());
const cudaStream_t stream = get_current_cuda_stream(a.get_device_index());

auto can_implement_status = gemm_op.can_implement(args);
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess,
"Failed to implement GEMM: status=", (int)can_implement_status);
STD_TORCH_CHECK(
can_implement_status == cutlass::Status::kSuccess,
"Failed to implement GEMM: status=", (int)can_implement_status);

// Run the GEMM
auto status = gemm_op.initialize(args, workspace.data_ptr());
TORCH_CHECK(status == cutlass::Status::kSuccess,
"Failed to initialize GEMM: status=", (int)status,
" workspace_size=", workspace_size, " num_experts=", num_experts,
" M=", M, " N=", N, " K=", K);
STD_TORCH_CHECK(status == cutlass::Status::kSuccess,
"Failed to initialize GEMM: status=", (int)status,
" workspace_size=", workspace_size,
" num_experts=", num_experts, " M=", M, " N=", N, " K=", K);

status = gemm_op.run(args, workspace.data_ptr(), stream);
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
STD_TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
}

void run_fp4_blockwise_scaled_group_mm_sm120(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M,
int N, int K) {
torch::stable::Tensor& output, const torch::stable::Tensor& a,
const torch::stable::Tensor& b, const torch::stable::Tensor& a_blockscale,
const torch::stable::Tensor& b_blockscales,
const torch::stable::Tensor& alphas,
const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& sf_offsets, int M, int N, int K) {
using ProblemShape =
cutlass::gemm::GroupProblemShape<Shape<int32_t, int32_t, int32_t>>;
using ElementType = cutlass::float_e2m1_t;
@@ -446,20 +481,40 @@ void run_fp4_blockwise_scaled_group_mm_sm120(

using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
int num_experts = static_cast<int>(expert_offsets.size(0));
auto options_int =
torch::TensorOptions().dtype(torch::kInt64).device(a.device());

torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int);
torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int);
torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int);
torch::Tensor a_strides1 = torch::empty(num_experts, options_int);
torch::Tensor b_strides1 = torch::empty(num_experts, options_int);
torch::Tensor c_strides1 = torch::empty(num_experts, options_int);
torch::stable::Tensor a_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor b_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor out_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor a_scales_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor b_scales_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor alpha_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor layout_sfa = torch::stable::empty(
{num_experts, 5}, torch::headeronly::ScalarType::Long, std::nullopt,
a.device());
torch::stable::Tensor layout_sfb = torch::stable::empty(
{num_experts, 5}, torch::headeronly::ScalarType::Long, std::nullopt,
a.device());
torch::stable::Tensor a_strides1 =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor b_strides1 =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor c_strides1 =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());

run_get_group_gemm_starts<LayoutSFA, LayoutSFB, ScaleConfig>(
a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, alpha_ptrs,
@@ -480,7 +535,7 @@ void run_fp4_blockwise_scaled_group_mm_sm120(
using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions;
typename Gemm::GemmKernel::TileSchedulerArguments scheduler;
scheduler.raster_order = RasterOrderOptions::AlongM;
hw_info.device_id = a.get_device();
hw_info.device_id = a.get_device_index();
static std::unordered_map<int, int> cached_sm_counts;
if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) {
cached_sm_counts[hw_info.device_id] =
@@ -523,33 +578,36 @@ void run_fp4_blockwise_scaled_group_mm_sm120(
scheduler};

size_t workspace_size = Gemm::get_workspace_size(args);
auto const workspace_options =
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
auto workspace = torch::empty(workspace_size, workspace_options);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device());
auto workspace =
torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
std::nullopt, a.device());
const cudaStream_t stream = get_current_cuda_stream(a.get_device_index());

auto can_implement_status = gemm_op.can_implement(args);
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess,
"Failed to implement GEMM: status=", (int)can_implement_status);
STD_TORCH_CHECK(
can_implement_status == cutlass::Status::kSuccess,
"Failed to implement GEMM: status=", (int)can_implement_status);

// Run the GEMM
auto status = gemm_op.initialize(args, workspace.data_ptr());
TORCH_CHECK(status == cutlass::Status::kSuccess,
"Failed to initialize GEMM: status=", (int)status,
" workspace_size=", workspace_size, " num_experts=", num_experts,
" M=", M, " N=", N, " K=", K);
STD_TORCH_CHECK(status == cutlass::Status::kSuccess,
"Failed to initialize GEMM: status=", (int)status,
" workspace_size=", workspace_size,
" num_experts=", num_experts, " M=", M, " N=", N, " K=", K);

status = gemm_op.run(args, workspace.data_ptr(), stream);
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
STD_TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
}

template <typename OutType>
void run_fp4_blockwise_scaled_group_mm(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M,
int N, int K) {
torch::stable::Tensor& output, const torch::stable::Tensor& a,
const torch::stable::Tensor& b, const torch::stable::Tensor& a_blockscale,
const torch::stable::Tensor& b_blockscales,
const torch::stable::Tensor& alphas,
const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& sf_offsets, int M, int N, int K) {
int32_t version_num = get_sm_version_num();
#if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
if (version_num >= 120 && version_num < 130) {
@@ -567,7 +625,7 @@ void run_fp4_blockwise_scaled_group_mm(
return;
}
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_fp4_group_mm kernel for CUDA device capability: ",
version_num, ". Required capability: 100 or 120");
@@ -575,26 +633,31 @@

#if (defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100) || \
(defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120)
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
constexpr auto FLOAT4_E2M1X2 = torch::headeronly::ScalarType::Byte;
constexpr auto SF_DTYPE = torch::headeronly::ScalarType::Float8_e4m3fn;
#endif

#define CHECK_TYPE(x, st, m) \
TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m)
#define CHECK_TYPE(x, st, m) \
STD_TORCH_CHECK(x.scalar_type() == st, \
": Inconsistency of torch::stable::Tensor type:", m)
#define CHECK_TH_CUDA(x, m) \
TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor.")
STD_TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor.")
#define CHECK_CONTIGUOUS(x, m) \
TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous.")
STD_TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous.")
#define CHECK_INPUT(x, st, m) \
CHECK_TH_CUDA(x, m); \
CHECK_CONTIGUOUS(x, m); \
CHECK_TYPE(x, st, m)

void cutlass_fp4_group_mm(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets) {
void cutlass_fp4_group_mm(torch::stable::Tensor& output,
const torch::stable::Tensor& a,
const torch::stable::Tensor& b,
const torch::stable::Tensor& a_blockscale,
const torch::stable::Tensor& b_blockscales,
const torch::stable::Tensor& alphas,
const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& sf_offsets) {
#if (defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100) || \
(defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120)
// Input validation
@@ -602,30 +665,34 @@ void cutlass_fp4_group_mm(
CHECK_INPUT(b, FLOAT4_E2M1X2, "b");
CHECK_INPUT(a_blockscale, SF_DTYPE, "a_blockscale");
CHECK_INPUT(b_blockscales, SF_DTYPE, "b_blockscales");
CHECK_INPUT(alphas, at::ScalarType::Float, "alphas");
CHECK_INPUT(alphas, torch::headeronly::ScalarType::Float, "alphas");

TORCH_CHECK(a_blockscale.dim() == 2,
"expected a_blockscale to be of shape [num_experts, rounded_m,"
" k // group_size], observed rank: ",
a_blockscale.dim())
TORCH_CHECK(b_blockscales.dim() == 3,
"expected b_blockscale to be of shape: "
" [num_experts, n, k // group_size], observed rank: ",
b_blockscales.dim())
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be a 2D tensor");
TORCH_CHECK(problem_sizes.size(1) == 3,
"problem_sizes must have the shape (num_experts, 3)");
TORCH_CHECK(problem_sizes.size(0) == expert_offsets.size(0),
"Number of experts in problem_sizes must match expert_offsets");
TORCH_CHECK(problem_sizes.dtype() == torch::kInt32,
"problem_sizes must be int32.");
STD_TORCH_CHECK(
a_blockscale.dim() == 2,
"expected a_blockscale to be of shape [num_experts, rounded_m,"
" k // group_size], observed rank: ",
a_blockscale.dim())
STD_TORCH_CHECK(b_blockscales.dim() == 3,
"expected b_blockscale to be of shape: "
" [num_experts, n, k // group_size], observed rank: ",
b_blockscales.dim())
STD_TORCH_CHECK(problem_sizes.dim() == 2,
"problem_sizes must be a 2D tensor");
STD_TORCH_CHECK(problem_sizes.size(1) == 3,
"problem_sizes must have the shape (num_experts, 3)");
STD_TORCH_CHECK(
problem_sizes.size(0) == expert_offsets.size(0),
"Number of experts in problem_sizes must match expert_offsets");
STD_TORCH_CHECK(
problem_sizes.scalar_type() == torch::headeronly::ScalarType::Int,
"problem_sizes must be int32.");

int M = static_cast<int>(a.size(0));
int N = static_cast<int>(b.size(1));
int E = static_cast<int>(b.size(0));
int K = static_cast<int>(2 * b.size(2));

if (output.scalar_type() == torch::kBFloat16) {
if (output.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
run_fp4_blockwise_scaled_group_mm<cutlass::bfloat16_t>(
output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes,
expert_offsets, sf_offsets, M, N, K);
@@ -633,7 +700,7 @@ void cutlass_fp4_group_mm(
#if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
int32_t version_num = get_sm_version_num();
if (version_num >= 120 && version_num < 130) {
TORCH_CHECK_NOT_IMPLEMENTED(
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false, "SM120 NVFP4 MOE only supports bfloat16 output, got: ",
output.scalar_type());
}
@@ -643,7 +710,7 @@ void cutlass_fp4_group_mm(
expert_offsets, sf_offsets, M, N, K);
}
#else
TORCH_CHECK_NOT_IMPLEMENTED(
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_fp4_group_mm kernel, vLLM must "
"be compiled with ENABLE_NVFP4_SM100 or ENABLE_NVFP4_SM120 for SM100/120 "
@@ -651,6 +718,6 @@ void cutlass_fp4_group_mm(
#endif
}

TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("cutlass_fp4_group_mm", &cutlass_fp4_group_mm);
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) {
m.impl("cutlass_fp4_group_mm", TORCH_BOX(&cutlass_fp4_group_mm));
}
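The `CHECK_TYPE` / `CHECK_TH_CUDA` / `CHECK_CONTIGUOUS` / `CHECK_INPUT` macros above compose the standard input-validation sequence; only their bodies change in this patch, from `TORCH_CHECK` to `STD_TORCH_CHECK`. A minimal usage sketch (the tensor name `a` and its expected dtype are illustrative, not taken from the diff):

// Expands to a CUDA-device check, a contiguity check, and a dtype check on `a`.
CHECK_INPUT(a, torch::headeronly::ScalarType::Byte, "a");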
@@ -14,16 +14,15 @@
* limitations under the License.
*/

#include <torch/all.h>
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "libtorch_stable/dispatch_utils.h"
#include "cuda_vec_utils.cuh"

#include <cuda_runtime_api.h>
#include <cuda_runtime.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <cuda_fp8.h>
#include "dispatch_utils.h"

#include "cuda_utils.h"
#include "nvfp4_utils.cuh"
@@ -327,25 +326,28 @@ void quant_impl(void* output, void* output_scale, void* input,
} // namespace vllm

/*Quantization entry for fp4 experts quantization*/
#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
#define CHECK_TH_CUDA(x, m) \
STD_TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x, m) \
TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
STD_TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
#define CHECK_INPUT(x, m) \
CHECK_TH_CUDA(x, m); \
CHECK_CONTIGUOUS(x, m);

constexpr auto HALF = at::ScalarType::Half;
constexpr auto BF16 = at::ScalarType::BFloat16;
constexpr auto FLOAT = at::ScalarType::Float;
constexpr auto INT = at::ScalarType::Int;
constexpr auto UINT8 = at::ScalarType::Byte;
constexpr auto HALF = torch::headeronly::ScalarType::Half;
constexpr auto BF16 = torch::headeronly::ScalarType::BFloat16;
constexpr auto FLOAT = torch::headeronly::ScalarType::Float;
constexpr auto INT = torch::headeronly::ScalarType::Int;
constexpr auto UINT8 = torch::headeronly::ScalarType::Byte;
|
||||
// Common validation for fp4 experts quantization entry points.
|
||||
static void validate_fp4_experts_quant_inputs(
|
||||
torch::Tensor const& output, torch::Tensor const& output_scale,
|
||||
torch::Tensor const& input, torch::Tensor const& input_global_scale,
|
||||
torch::Tensor const& input_offset_by_experts,
|
||||
torch::Tensor const& output_scale_offset_by_experts, int64_t m_topk,
|
||||
torch::stable::Tensor const& output,
|
||||
torch::stable::Tensor const& output_scale,
|
||||
torch::stable::Tensor const& input,
|
||||
torch::stable::Tensor const& input_global_scale,
|
||||
torch::stable::Tensor const& input_offset_by_experts,
|
||||
torch::stable::Tensor const& output_scale_offset_by_experts, int64_t m_topk,
|
||||
int64_t k) {
|
||||
CHECK_INPUT(output, "output");
|
||||
CHECK_INPUT(output_scale, "output_scale");
|
||||
@@ -354,41 +356,42 @@ static void validate_fp4_experts_quant_inputs(
|
||||
CHECK_INPUT(input_offset_by_experts, "input_offset_by_experts");
|
||||
CHECK_INPUT(output_scale_offset_by_experts, "output_scale_offset_by_experts");
|
||||
|
||||
TORCH_CHECK(output.dim() == 2);
|
||||
TORCH_CHECK(output_scale.dim() == 2);
|
||||
TORCH_CHECK(input.dim() == 2);
|
||||
TORCH_CHECK(input_global_scale.dim() == 1);
|
||||
TORCH_CHECK(input_offset_by_experts.dim() == 1);
|
||||
TORCH_CHECK(output_scale_offset_by_experts.dim() == 1);
|
||||
STD_TORCH_CHECK(output.dim() == 2);
|
||||
STD_TORCH_CHECK(output_scale.dim() == 2);
|
||||
STD_TORCH_CHECK(input.dim() == 2);
|
||||
STD_TORCH_CHECK(input_global_scale.dim() == 1);
|
||||
STD_TORCH_CHECK(input_offset_by_experts.dim() == 1);
|
||||
STD_TORCH_CHECK(output_scale_offset_by_experts.dim() == 1);
|
||||
|
||||
TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16);
|
||||
TORCH_CHECK(input_global_scale.scalar_type() == FLOAT);
|
||||
TORCH_CHECK(input_offset_by_experts.scalar_type() == INT);
|
||||
TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT);
|
||||
STD_TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16);
|
||||
STD_TORCH_CHECK(input_global_scale.scalar_type() == FLOAT);
|
||||
STD_TORCH_CHECK(input_offset_by_experts.scalar_type() == INT);
|
||||
STD_TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT);
|
||||
// output is uint8 (two nvfp4 values are packed into one uint8)
|
||||
// output_scale is int32 (four fp8 values are packed into one int32)
|
||||
TORCH_CHECK(output.scalar_type() == UINT8);
|
||||
TORCH_CHECK(output_scale.scalar_type() == INT);
|
||||
STD_TORCH_CHECK(output.scalar_type() == UINT8);
|
||||
STD_TORCH_CHECK(output_scale.scalar_type() == INT);
|
||||
|
||||
const int BLOCK_SIZE = 16;
|
||||
TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16");
|
||||
STD_TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16");
|
||||
auto n_experts = input_global_scale.size(0);
|
||||
TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1);
|
||||
TORCH_CHECK(output_scale_offset_by_experts.size(0) == n_experts + 1);
|
||||
TORCH_CHECK(output.size(0) == m_topk);
|
||||
TORCH_CHECK(output.size(1) == k / 2);
|
||||
STD_TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1);
|
||||
STD_TORCH_CHECK(output_scale_offset_by_experts.size(0) == n_experts + 1);
|
||||
STD_TORCH_CHECK(output.size(0) == m_topk);
|
||||
STD_TORCH_CHECK(output.size(1) == k / 2);
|
||||
int scales_k = k / BLOCK_SIZE;
|
||||
// 4 means the swizzle requirement by nvidia nvfp4.
|
||||
int padded_k = (scales_k + (4 - 1)) / 4 * 4;
|
||||
// 4 means 4 fp8 values are packed into one int32
|
||||
TORCH_CHECK(output_scale.size(1) * 4 == padded_k);
|
||||
STD_TORCH_CHECK(output_scale.size(1) * 4 == padded_k);
|
||||
}
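The block-scale layout enforced by the checks above can be made concrete with a small worked example. This is an illustration added for clarity, not code from the change; the constants follow the BLOCK_SIZE, swizzle-padding, and int32-packing rules stated in the comments.

// Worked example of the nvfp4 scale-layout arithmetic validated above.
// For k = 48 (a multiple of 16), each row carries 3 fp8 block scales; the
// NVIDIA swizzle pads that to 4, and 4 fp8 scales pack into one int32, so
// output_scale must have exactly 1 column.
constexpr int k = 48;
constexpr int BLOCK_SIZE = 16;
constexpr int scales_k = k / BLOCK_SIZE;                 // 3
constexpr int padded_k = (scales_k + (4 - 1)) / 4 * 4;   // 4
constexpr int output_scale_cols = padded_k / 4;          // 1
static_assert(output_scale_cols * 4 == padded_k, "matches the check above");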
|
||||
|
||||
void scaled_fp4_experts_quant_sm1xxa(
|
||||
torch::Tensor& output, torch::Tensor& output_scale,
|
||||
torch::Tensor const& input, torch::Tensor const& input_global_scale,
|
||||
torch::Tensor const& input_offset_by_experts,
|
||||
torch::Tensor const& output_scale_offset_by_experts) {
|
||||
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
|
||||
torch::stable::Tensor const& input,
|
||||
torch::stable::Tensor const& input_global_scale,
|
||||
torch::stable::Tensor const& input_offset_by_experts,
|
||||
torch::stable::Tensor const& output_scale_offset_by_experts) {
|
||||
auto m_topk = input.size(0);
|
||||
auto k = input.size(1);
|
||||
|
||||
@@ -397,11 +400,11 @@ void scaled_fp4_experts_quant_sm1xxa(
|
||||
output_scale_offset_by_experts, m_topk, k);
|
||||
|
||||
auto n_experts = input_global_scale.size(0);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream =
|
||||
at::cuda::getCurrentCUDAStream(input.get_device());
|
||||
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||
input.get_device_index());
|
||||
const cudaStream_t stream = get_current_cuda_stream(input.get_device_index());
|
||||
|
||||
VLLM_DISPATCH_HALF_TYPES(
|
||||
VLLM_STABLE_DISPATCH_HALF_TYPES(
|
||||
input.scalar_type(), "nvfp4_experts_quant_kernel", [&] {
|
||||
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
|
||||
vllm::quant_impl<cuda_type, /*FUSE_SILU_MUL=*/false>(
|
||||
@@ -413,14 +416,15 @@ void scaled_fp4_experts_quant_sm1xxa(
|
||||
}
|
||||
|
||||
void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
|
||||
torch::Tensor& output, torch::Tensor& output_scale,
|
||||
torch::Tensor const& input, torch::Tensor const& input_global_scale,
|
||||
torch::Tensor const& input_offset_by_experts,
|
||||
torch::Tensor const& output_scale_offset_by_experts) {
|
||||
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
|
||||
torch::stable::Tensor const& input,
|
||||
torch::stable::Tensor const& input_global_scale,
|
||||
torch::stable::Tensor const& input_offset_by_experts,
|
||||
torch::stable::Tensor const& output_scale_offset_by_experts) {
|
||||
auto m_topk = input.size(0);
|
||||
// Input has gate || up layout, so k = input.size(1) / 2
|
||||
auto k_times_2 = input.size(1);
|
||||
TORCH_CHECK(k_times_2 % 2 == 0, "input width must be even (gate || up)");
|
||||
STD_TORCH_CHECK(k_times_2 % 2 == 0, "input width must be even (gate || up)");
|
||||
auto k = k_times_2 / 2;
|
||||
|
||||
validate_fp4_experts_quant_inputs(output, output_scale, input,
|
||||
@@ -428,11 +432,11 @@ void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
|
||||
output_scale_offset_by_experts, m_topk, k);
|
||||
|
||||
auto n_experts = input_global_scale.size(0);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream =
|
||||
at::cuda::getCurrentCUDAStream(input.get_device());
|
||||
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||
input.get_device_index());
|
||||
const cudaStream_t stream = get_current_cuda_stream(input.get_device_index());
|
||||
|
||||
VLLM_DISPATCH_HALF_TYPES(
|
||||
VLLM_STABLE_DISPATCH_HALF_TYPES(
|
||||
input.scalar_type(), "silu_mul_nvfp4_experts_quant_kernel", [&] {
|
||||
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
|
||||
vllm::quant_impl<cuda_type, /*FUSE_SILU_MUL=*/true>(
|
||||
175
csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu
Normal file
@@ -0,0 +1,175 @@
|
||||
/*
|
||||
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
#include "nvfp4_utils.cuh"
|
||||
|
||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||
void scaled_fp4_quant_sm1xxa(torch::stable::Tensor const& output,
|
||||
torch::stable::Tensor const& input,
|
||||
torch::stable::Tensor const& output_sf,
|
||||
torch::stable::Tensor const& input_sf,
|
||||
bool is_sf_swizzled_layout);
|
||||
#endif
|
||||
|
||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||
void scaled_fp4_experts_quant_sm1xxa(
|
||||
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
|
||||
torch::stable::Tensor const& input,
|
||||
torch::stable::Tensor const& input_global_scale,
|
||||
torch::stable::Tensor const& input_offset_by_experts,
|
||||
torch::stable::Tensor const& output_scale_offset_by_experts);
|
||||
#endif
|
||||
|
||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||
void silu_and_mul_nvfp4_quant_sm1xxa(torch::stable::Tensor& output,
|
||||
torch::stable::Tensor& output_sf,
|
||||
torch::stable::Tensor& input,
|
||||
torch::stable::Tensor& input_sf);
|
||||
#endif
|
||||
|
||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||
void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
|
||||
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
|
||||
torch::stable::Tensor const& input,
|
||||
torch::stable::Tensor const& input_global_scale,
|
||||
torch::stable::Tensor const& input_offset_by_experts,
|
||||
torch::stable::Tensor const& output_scale_offset_by_experts);
|
||||
#endif
|
||||
|
||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
static bool nvfp4_quant_sm_supported() {
const int32_t sm = get_sm_version_num();
#if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100
if (sm >= 100 && sm < 120) return true;
#endif
#if defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120
if (sm >= 120 && sm < 130) return true;
#endif
return false;
}
#endif

void scaled_fp4_quant_out(torch::stable::Tensor const& input,
torch::stable::Tensor const& input_sf,
bool is_sf_swizzled_layout,
torch::stable::Tensor& output,
torch::stable::Tensor& output_sf) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
STD_TORCH_CHECK(nvfp4_quant_sm_supported(),
"No compiled nvfp4 quantization kernel for SM ",
get_sm_version_num(),
". Recompile with the appropriate CUDA arch.");
return scaled_fp4_quant_sm1xxa(output, input, output_sf, input_sf,
is_sf_swizzled_layout);
#endif
STD_TORCH_CHECK_NOT_IMPLEMENTED(false,
"No compiled nvfp4 quantization kernel");
}
|
||||
|
||||
std::tuple<torch::stable::Tensor, torch::stable::Tensor> scaled_fp4_quant_func(
|
||||
torch::stable::Tensor const& input, torch::stable::Tensor const& input_sf,
|
||||
bool is_sf_swizzled_layout) {
|
||||
int64_t n = input.size(-1);
|
||||
int64_t m = input.numel() / n;
|
||||
auto device = input.device();
|
||||
|
||||
// Two fp4 values packed into a uint8
|
||||
auto output = torch::stable::empty(
|
||||
{m, n / 2}, torch::headeronly::ScalarType::Byte, std::nullopt, device);
|
||||
|
||||
torch::stable::Tensor output_sf;
|
||||
if (is_sf_swizzled_layout) {
|
||||
auto [sf_m, sf_n] = vllm::computeSwizzledSFShape(m, n);
|
||||
output_sf = torch::stable::empty(
|
||||
{sf_m, sf_n}, torch::headeronly::ScalarType::Int, std::nullopt, device);
|
||||
} else {
|
||||
output_sf = torch::stable::empty({m, n / CVT_FP4_SF_VEC_SIZE},
|
||||
torch::headeronly::ScalarType::Byte,
|
||||
std::nullopt, device);
|
||||
}
|
||||
|
||||
scaled_fp4_quant_out(input, input_sf, is_sf_swizzled_layout, output,
|
||||
output_sf);
|
||||
return {output, output_sf};
|
||||
}
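A caller-side sketch of the allocating entry point above may help; it is hypothetical (the wrapper and tensor names are placeholders) and only assumes what the signature itself implies: a contiguous half or bfloat16 CUDA input plus a float32 scale tensor for the kernel.

// Hypothetical usage sketch of scaled_fp4_quant_func (defined above).
// `activations` is a contiguous [m, n] CUDA tensor in bf16/fp16 and
// `global_scale` is the float32 CUDA scale tensor the kernel expects.
#include <tuple>
#include <torch/csrc/stable/tensor.h>

std::tuple<torch::stable::Tensor, torch::stable::Tensor> quantize_to_nvfp4(
    torch::stable::Tensor const& activations,
    torch::stable::Tensor const& global_scale) {
  // The swizzled scale-factor layout is the form the CUTLASS FP4 GEMMs consume.
  return scaled_fp4_quant_func(activations, global_scale,
                               /*is_sf_swizzled_layout=*/true);
}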
|
||||
|
||||
void scaled_fp4_experts_quant(
|
||||
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
|
||||
torch::stable::Tensor const& input,
|
||||
torch::stable::Tensor const& input_global_scale,
|
||||
torch::stable::Tensor const& input_offset_by_experts,
|
||||
torch::stable::Tensor const& output_scale_offset_by_experts) {
|
||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||
STD_TORCH_CHECK(nvfp4_quant_sm_supported(),
|
||||
"No compiled nvfp4 experts quantization kernel for SM ",
|
||||
get_sm_version_num(),
|
||||
". Recompile with the appropriate CUDA arch.");
|
||||
return scaled_fp4_experts_quant_sm1xxa(
|
||||
output, output_scale, input, input_global_scale, input_offset_by_experts,
|
||||
output_scale_offset_by_experts);
|
||||
#endif
|
||||
STD_TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false, "No compiled nvfp4 experts quantization kernel");
|
||||
}
|
||||
|
||||
void silu_and_mul_nvfp4_quant(torch::stable::Tensor& output,
|
||||
torch::stable::Tensor& output_sf,
|
||||
torch::stable::Tensor& input,
|
||||
torch::stable::Tensor& input_sf) {
|
||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||
STD_TORCH_CHECK(nvfp4_quant_sm_supported(),
|
||||
"No compiled silu_and_mul nvfp4 quantization kernel for SM ",
|
||||
get_sm_version_num(),
|
||||
". Recompile with the appropriate CUDA arch.");
|
||||
return silu_and_mul_nvfp4_quant_sm1xxa(output, output_sf, input, input_sf);
|
||||
#endif
|
||||
STD_TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false, "No compiled silu_and_mul nvfp4 quantization kernel");
|
||||
}
|
||||
|
||||
void silu_and_mul_scaled_fp4_experts_quant(
|
||||
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
|
||||
torch::stable::Tensor const& input,
|
||||
torch::stable::Tensor const& input_global_scale,
|
||||
torch::stable::Tensor const& input_offset_by_experts,
|
||||
torch::stable::Tensor const& output_scale_offset_by_experts) {
|
||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||
STD_TORCH_CHECK(nvfp4_quant_sm_supported(),
|
||||
"No compiled silu_and_mul nvfp4 experts quantization kernel "
|
||||
"for SM ",
|
||||
get_sm_version_num(),
|
||||
". Recompile with the appropriate CUDA arch.");
|
||||
return silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
|
||||
output, output_scale, input, input_global_scale, input_offset_by_experts,
|
||||
output_scale_offset_by_experts);
|
||||
#endif
|
||||
STD_TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false, "No compiled silu_and_mul nvfp4 experts quantization kernel");
|
||||
}
|
||||
@@ -14,16 +14,16 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <torch/all.h>
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include <cuda_fp8.h>
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
#include "libtorch_stable/dispatch_utils.h"
|
||||
#include "cuda_vec_utils.cuh"
|
||||
|
||||
#include "cuda_utils.h"
|
||||
#include "launch_bounds_utils.h"
|
||||
@@ -173,18 +173,19 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
|
||||
torch::Tensor const& input,
|
||||
torch::Tensor const& output_sf,
|
||||
torch::Tensor const& input_sf,
|
||||
void scaled_fp4_quant_sm1xxa(torch::stable::Tensor const& output,
|
||||
torch::stable::Tensor const& input,
|
||||
torch::stable::Tensor const& output_sf,
|
||||
torch::stable::Tensor const& input_sf,
|
||||
bool is_sf_swizzled_layout) {
|
||||
int32_t m = input.size(0);
|
||||
int32_t n = input.size(1);
|
||||
|
||||
TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");
|
||||
TORCH_CHECK(input.scalar_type() == at::ScalarType::Half ||
|
||||
input.scalar_type() == at::ScalarType::BFloat16,
|
||||
"Unsupported input data type for quantize_to_fp4.");
|
||||
STD_TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");
|
||||
STD_TORCH_CHECK(
|
||||
input.scalar_type() == torch::headeronly::ScalarType::Half ||
|
||||
input.scalar_type() == torch::headeronly::ScalarType::BFloat16,
|
||||
"Unsupported input data type for quantize_to_fp4.");
|
||||
|
||||
int multiProcessorCount =
|
||||
get_device_attribute(cudaDevAttrMultiProcessorCount, -1);
|
||||
@@ -192,8 +193,9 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
|
||||
auto input_sf_ptr = static_cast<float const*>(input_sf.data_ptr());
|
||||
auto sf_out = static_cast<int32_t*>(output_sf.data_ptr());
|
||||
auto output_ptr = static_cast<int64_t*>(output.data_ptr());
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
|
||||
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||
input.get_device_index());
|
||||
auto stream = get_current_cuda_stream(input.get_device_index());
|
||||
|
||||
int sf_n_unpadded = int(n / CVT_FP4_SF_VEC_SIZE);
|
||||
|
||||
@@ -213,15 +215,15 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
|
||||
std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y));
|
||||
dim3 grid(grid_x, grid_y);
|
||||
|
||||
VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] {
|
||||
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
|
||||
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
|
||||
// NOTE: We don't support e8m0 scales at this moment.
|
||||
vllm::cvt_fp16_to_fp4<cuda_type, false><<<grid, block, 0, stream>>>(
|
||||
m, n, num_padded_cols, input_ptr, input_sf_ptr,
|
||||
reinterpret_cast<uint32_t*>(output_ptr),
|
||||
reinterpret_cast<uint32_t*>(sf_out));
|
||||
});
|
||||
VLLM_STABLE_DISPATCH_HALF_TYPES(
|
||||
input.scalar_type(), "nvfp4_quant_kernel", [&] {
|
||||
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
|
||||
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
|
||||
vllm::cvt_fp16_to_fp4<cuda_type, false><<<grid, block, 0, stream>>>(
|
||||
m, n, num_padded_cols, input_ptr, input_sf_ptr,
|
||||
reinterpret_cast<uint32_t*>(output_ptr),
|
||||
reinterpret_cast<uint32_t*>(sf_out));
|
||||
});
|
||||
} else {
|
||||
int num_packed_cols = n / CVT_FP4_ELTS_PER_THREAD;
|
||||
int grid_y = vllm::div_round_up(num_packed_cols, static_cast<int>(block.x));
|
||||
@@ -229,15 +231,15 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
|
||||
m, std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y));
|
||||
dim3 grid(grid_x, grid_y);
|
||||
|
||||
VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] {
|
||||
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
|
||||
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
|
||||
// NOTE: We don't support e8m0 scales at this moment.
|
||||
vllm::cvt_fp16_to_fp4_sf_major<cuda_type, false>
|
||||
<<<grid, block, 0, stream>>>(m, n, sf_n_unpadded, num_packed_cols,
|
||||
input_ptr, input_sf_ptr,
|
||||
reinterpret_cast<uint32_t*>(output_ptr),
|
||||
reinterpret_cast<uint32_t*>(sf_out));
|
||||
});
|
||||
VLLM_STABLE_DISPATCH_HALF_TYPES(
|
||||
input.scalar_type(), "nvfp4_quant_kernel", [&] {
|
||||
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
|
||||
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
|
||||
vllm::cvt_fp16_to_fp4_sf_major<cuda_type, false>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
m, n, sf_n_unpadded, num_packed_cols, input_ptr, input_sf_ptr,
|
||||
reinterpret_cast<uint32_t*>(output_ptr),
|
||||
reinterpret_cast<uint32_t*>(sf_out));
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,87 @@
|
||||
/*
|
||||
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
|
||||
#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
|
||||
void cutlass_scaled_fp4_mm_sm100a(torch::stable::Tensor& D,
|
||||
torch::stable::Tensor const& A,
|
||||
torch::stable::Tensor const& B,
|
||||
torch::stable::Tensor const& A_sf,
|
||||
torch::stable::Tensor const& B_sf,
|
||||
torch::stable::Tensor const& alpha);
|
||||
#endif
|
||||
|
||||
#if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
|
||||
void cutlass_scaled_fp4_mm_sm120a(torch::stable::Tensor& D,
|
||||
torch::stable::Tensor const& A,
|
||||
torch::stable::Tensor const& B,
|
||||
torch::stable::Tensor const& A_sf,
|
||||
torch::stable::Tensor const& B_sf,
|
||||
torch::stable::Tensor const& alpha);
|
||||
#endif
|
||||
|
||||
void cutlass_scaled_fp4_mm(torch::stable::Tensor& D,
|
||||
const torch::stable::Tensor& A,
|
||||
const torch::stable::Tensor& B,
|
||||
const torch::stable::Tensor& A_sf,
|
||||
const torch::stable::Tensor& B_sf,
|
||||
const torch::stable::Tensor& alpha) {
|
||||
// Make sure we're on A's device.
|
||||
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||
A.get_device_index());
|
||||
const int32_t sm = get_sm_version_num();
|
||||
|
||||
#if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100
|
||||
if (sm >= 100 && sm < 120) {
|
||||
cutlass_scaled_fp4_mm_sm100a(D, A, B, A_sf, B_sf, alpha);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120
|
||||
if (sm >= 120 && sm < 130) {
|
||||
cutlass_scaled_fp4_mm_sm120a(D, A, B, A_sf, B_sf, alpha);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
STD_TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false, "No compiled nvfp4 mm kernel for SM ", sm,
|
||||
". Recompile with CUDA >= 12.8 and CC >= 100.");
|
||||
}
|
||||
|
||||
bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability) {
int runtimeVersion;
cudaRuntimeGetVersion(&runtimeVersion);
if (runtimeVersion < 12080) return false;
// Only report support when the SM-specific kernel was actually compiled in,
// so the Python-side backend selector does not choose CUTLASS and then hit
// TORCH_CHECK_NOT_IMPLEMENTED (or worse, fall through to Marlin).
#if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100
if (cuda_device_capability >= 100 && cuda_device_capability < 120)
return true;
#endif
#if defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120
if (cuda_device_capability >= 120 && cuda_device_capability < 130)
return true;
#endif
return false;
}
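To illustrate why the compile-time gating above matters, here is a hedged host-side sketch of how a caller might consult cutlass_scaled_mm_supports_fp4 before dispatching. It is not part of this file; fallback_fp4_mm is a placeholder for whatever non-CUTLASS path the caller would otherwise take.

// Illustrative selection sketch. cutlass_scaled_mm_supports_fp4 and
// cutlass_scaled_fp4_mm are declared above; fallback_fp4_mm is hypothetical.
void fallback_fp4_mm(torch::stable::Tensor& D, torch::stable::Tensor const& A,
                     torch::stable::Tensor const& B,
                     torch::stable::Tensor const& A_sf,
                     torch::stable::Tensor const& B_sf,
                     torch::stable::Tensor const& alpha);

void fp4_mm_with_fallback(torch::stable::Tensor& D,
                          torch::stable::Tensor const& A,
                          torch::stable::Tensor const& B,
                          torch::stable::Tensor const& A_sf,
                          torch::stable::Tensor const& B_sf,
                          torch::stable::Tensor const& alpha,
                          int64_t device_capability) {
  if (cutlass_scaled_mm_supports_fp4(device_capability)) {
    cutlass_scaled_fp4_mm(D, A, B, A_sf, B_sf, alpha);  // compiled-in CUTLASS path
  } else {
    fallback_fp4_mm(D, A, B, A_sf, B_sf, alpha);        // e.g. a Marlin-style path
  }
}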
|
||||
@@ -14,10 +14,9 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <torch/all.h>
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
|
||||
@@ -127,8 +126,9 @@ struct Fp4GemmSm100 {
|
||||
|
||||
template <typename Config>
|
||||
typename Config::Gemm::Arguments args_from_options(
|
||||
at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
|
||||
at::Tensor const& A_sf, at::Tensor const& B_sf, at::Tensor const& alpha,
|
||||
torch::stable::Tensor& D, torch::stable::Tensor const& A,
|
||||
torch::stable::Tensor const& B, torch::stable::Tensor const& A_sf,
|
||||
torch::stable::Tensor const& B_sf, torch::stable::Tensor const& alpha,
|
||||
int64_t M, int64_t N, int64_t K) {
|
||||
using ElementA = typename Config::Gemm::ElementA;
|
||||
using ElementB = typename Config::Gemm::ElementB;
|
||||
@@ -174,19 +174,20 @@ typename Config::Gemm::Arguments args_from_options(
|
||||
}
|
||||
|
||||
template <typename Config>
|
||||
void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
|
||||
at::Tensor const& A_sf, at::Tensor const& B_sf,
|
||||
at::Tensor const& alpha, int64_t m, int64_t n, int64_t k,
|
||||
cudaStream_t stream) {
|
||||
void runGemm(torch::stable::Tensor& D, torch::stable::Tensor const& A,
|
||||
torch::stable::Tensor const& B, torch::stable::Tensor const& A_sf,
|
||||
torch::stable::Tensor const& B_sf,
|
||||
torch::stable::Tensor const& alpha, int64_t m, int64_t n,
|
||||
int64_t k, cudaStream_t stream) {
|
||||
typename Config::Gemm gemm;
|
||||
|
||||
auto arguments =
|
||||
args_from_options<Config>(D, A, B, A_sf, B_sf, alpha, m, n, k);
|
||||
|
||||
size_t workspace_size = Config::Gemm::get_workspace_size(arguments);
|
||||
auto const workspace_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
auto workspace =
|
||||
torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
|
||||
std::nullopt, A.device());
|
||||
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
@@ -197,12 +198,13 @@ void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
|
||||
|
||||
// Dispatch function to select appropriate config based on M
|
||||
template <typename OutType>
|
||||
void cutlass_fp4_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
|
||||
torch::Tensor const& B,
|
||||
torch::Tensor const& A_sf,
|
||||
torch::Tensor const& B_sf,
|
||||
torch::Tensor const& alpha, int64_t m, int64_t n,
|
||||
int64_t k, cudaStream_t stream) {
|
||||
void cutlass_fp4_gemm_dispatch(torch::stable::Tensor& D,
|
||||
torch::stable::Tensor const& A,
|
||||
torch::stable::Tensor const& B,
|
||||
torch::stable::Tensor const& A_sf,
|
||||
torch::stable::Tensor const& B_sf,
|
||||
torch::stable::Tensor const& alpha, int64_t m,
|
||||
int64_t n, int64_t k, cudaStream_t stream) {
|
||||
uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
|
||||
|
||||
if (mp2 <= 16) {
|
||||
@@ -222,61 +224,65 @@ void cutlass_fp4_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
|
||||
|
||||
#else
|
||||
template <typename OutType>
|
||||
void cutlass_fp4_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
|
||||
torch::Tensor const& B,
|
||||
torch::Tensor const& A_sf,
|
||||
torch::Tensor const& B_sf,
|
||||
torch::Tensor const& alpha, int64_t m, int64_t n,
|
||||
int64_t k, cudaStream_t stream) {
|
||||
TORCH_CHECK(false,
|
||||
"Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
|
||||
"a CUTLASS 3.8 source directory to enable support.");
|
||||
void cutlass_fp4_gemm_dispatch(torch::stable::Tensor& D,
|
||||
torch::stable::Tensor const& A,
|
||||
torch::stable::Tensor const& B,
|
||||
torch::stable::Tensor const& A_sf,
|
||||
torch::stable::Tensor const& B_sf,
|
||||
torch::stable::Tensor const& alpha, int64_t m,
|
||||
int64_t n, int64_t k, cudaStream_t stream) {
|
||||
STD_TORCH_CHECK(false,
|
||||
"Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
|
||||
"a CUTLASS 3.8 source directory to enable support.");
|
||||
}
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
#define CHECK_TYPE(x, st, m) \
|
||||
TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m)
|
||||
#define CHECK_TYPE(x, st, m) \
|
||||
STD_TORCH_CHECK(x.scalar_type() == st, \
|
||||
": Inconsistency of torch::stable::Tensor type:", m)
|
||||
#define CHECK_TH_CUDA(x, m) \
|
||||
TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor")
|
||||
STD_TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x, m) \
|
||||
TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous")
|
||||
STD_TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous")
|
||||
#define CHECK_INPUT(x, st, m) \
|
||||
CHECK_TH_CUDA(x, m); \
|
||||
CHECK_CONTIGUOUS(x, m); \
|
||||
CHECK_TYPE(x, st, m)
|
||||
|
||||
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
|
||||
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
|
||||
constexpr auto FLOAT4_E2M1X2 = torch::headeronly::ScalarType::Byte;
|
||||
constexpr auto SF_DTYPE = torch::headeronly::ScalarType::Float8_e4m3fn;
|
||||
|
||||
void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
|
||||
torch::Tensor const& B,
|
||||
torch::Tensor const& A_sf,
|
||||
torch::Tensor const& B_sf,
|
||||
torch::Tensor const& alpha) {
|
||||
void cutlass_scaled_fp4_mm_sm100a(torch::stable::Tensor& D,
|
||||
torch::stable::Tensor const& A,
|
||||
torch::stable::Tensor const& B,
|
||||
torch::stable::Tensor const& A_sf,
|
||||
torch::stable::Tensor const& B_sf,
|
||||
torch::stable::Tensor const& alpha) {
|
||||
CHECK_INPUT(A, FLOAT4_E2M1X2, "a");
|
||||
CHECK_INPUT(B, FLOAT4_E2M1X2, "b");
|
||||
|
||||
CHECK_INPUT(A_sf, SF_DTYPE, "scale_a");
|
||||
CHECK_INPUT(B_sf, SF_DTYPE, "scale_b");
|
||||
|
||||
CHECK_INPUT(alpha, at::ScalarType::Float, "alpha");
|
||||
CHECK_INPUT(alpha, torch::headeronly::ScalarType::Float, "alpha");
|
||||
|
||||
TORCH_CHECK(A.dim() == 2, "a must be a matrix");
|
||||
TORCH_CHECK(B.dim() == 2, "b must be a matrix");
|
||||
TORCH_CHECK(A.sizes()[1] == B.sizes()[1],
|
||||
"a and b shapes cannot be multiplied (", A.sizes()[0], "x",
|
||||
A.sizes()[1], " and ", B.sizes()[0], "x", B.sizes()[1], ")");
|
||||
STD_TORCH_CHECK(A.dim() == 2, "a must be a matrix");
|
||||
STD_TORCH_CHECK(B.dim() == 2, "b must be a matrix");
|
||||
STD_TORCH_CHECK(A.size(1) == B.size(1),
|
||||
"a and b shapes cannot be multiplied (", A.size(0), "x",
|
||||
A.size(1), " and ", B.size(0), "x", B.size(1), ")");
|
||||
|
||||
auto const m = A.sizes()[0];
|
||||
auto const n = B.sizes()[0];
|
||||
auto const k = A.sizes()[1] * 2;
|
||||
auto const m = A.size(0);
|
||||
auto const n = B.size(0);
|
||||
auto const k = A.size(1) * 2;
|
||||
|
||||
constexpr int alignment = 32;
|
||||
TORCH_CHECK(k % alignment == 0, "Expected k to be divisible by ", alignment,
|
||||
", but got a shape: (", A.sizes()[0], "x", A.sizes()[1],
|
||||
"), k: ", k, ".");
|
||||
TORCH_CHECK(n % alignment == 0, "Expected n to be divisible by ", alignment,
|
||||
", but got b shape: (", B.sizes()[0], "x", B.sizes()[1], ").");
|
||||
STD_TORCH_CHECK(k % alignment == 0, "Expected k to be divisible by ",
|
||||
alignment, ", but got a shape: (", A.size(0), "x", A.size(1),
|
||||
"), k: ", k, ".");
|
||||
STD_TORCH_CHECK(n % alignment == 0, "Expected n to be divisible by ",
|
||||
alignment, ", but got b shape: (", B.size(0), "x", B.size(1),
|
||||
").");
|
||||
|
||||
auto round_up = [](int x, int y) { return (x + y - 1) / y * y; };
|
||||
int rounded_m = round_up(m, 128);
|
||||
@@ -285,33 +291,34 @@ void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
|
||||
// integer.
|
||||
int rounded_k = round_up(k / 16, 4);
|
||||
|
||||
TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix");
|
||||
TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix");
|
||||
TORCH_CHECK(A_sf.sizes()[1] == B_sf.sizes()[1],
|
||||
"scale_a and scale_b shapes cannot be multiplied (",
|
||||
A_sf.sizes()[0], "x", A_sf.sizes()[1], " and ", B_sf.sizes()[0],
|
||||
"x", B_sf.sizes()[1], ")");
|
||||
TORCH_CHECK(A_sf.sizes()[0] == rounded_m && A_sf.sizes()[1] == rounded_k,
|
||||
"scale_a must be padded and swizzled to a shape (", rounded_m,
|
||||
"x", rounded_k, "), but got a shape (", A_sf.sizes()[0], "x",
|
||||
A_sf.sizes()[1], ")");
|
||||
TORCH_CHECK(B_sf.sizes()[0] == rounded_n && B_sf.sizes()[1] == rounded_k,
|
||||
"scale_b must be padded and swizzled to a shape (", rounded_n,
|
||||
"x", rounded_k, "), but got a shape (", B_sf.sizes()[0], "x",
|
||||
B_sf.sizes()[1], ")");
|
||||
STD_TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix");
|
||||
STD_TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix");
|
||||
STD_TORCH_CHECK(A_sf.size(1) == B_sf.size(1),
|
||||
"scale_a and scale_b shapes cannot be multiplied (",
|
||||
A_sf.size(0), "x", A_sf.size(1), " and ", B_sf.size(0), "x",
|
||||
B_sf.size(1), ")");
|
||||
STD_TORCH_CHECK(A_sf.size(0) == rounded_m && A_sf.size(1) == rounded_k,
|
||||
"scale_a must be padded and swizzled to a shape (", rounded_m,
|
||||
"x", rounded_k, "), but got a shape (", A_sf.size(0), "x",
|
||||
A_sf.size(1), ")");
|
||||
STD_TORCH_CHECK(B_sf.size(0) == rounded_n && B_sf.size(1) == rounded_k,
|
||||
"scale_b must be padded and swizzled to a shape (", rounded_n,
|
||||
"x", rounded_k, "), but got a shape (", B_sf.size(0), "x",
|
||||
B_sf.size(1), ")");
|
||||
|
||||
auto out_dtype = D.dtype();
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device());
|
||||
auto out_dtype = D.scalar_type();
|
||||
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||
A.get_device_index());
|
||||
const cudaStream_t stream = get_current_cuda_stream(A.get_device_index());
|
||||
|
||||
if (out_dtype == at::ScalarType::Half) {
|
||||
if (out_dtype == torch::headeronly::ScalarType::Half) {
|
||||
cutlass_fp4_gemm_dispatch<cutlass::half_t>(D, A, B, A_sf, B_sf, alpha, m, n,
|
||||
k, stream);
|
||||
} else if (out_dtype == at::ScalarType::BFloat16) {
|
||||
} else if (out_dtype == torch::headeronly::ScalarType::BFloat16) {
|
||||
cutlass_fp4_gemm_dispatch<cutlass::bfloat16_t>(D, A, B, A_sf, B_sf, alpha,
|
||||
m, n, k, stream);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm (", out_dtype,
|
||||
")");
|
||||
STD_TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm (",
|
||||
out_dtype, ")");
|
||||
}
|
||||
}
|
||||
@@ -14,10 +14,9 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <torch/all.h>
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
|
||||
@@ -34,19 +33,20 @@
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#define CHECK_TYPE(x, st, m) \
|
||||
TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m)
|
||||
#define CHECK_TYPE(x, st, m) \
|
||||
STD_TORCH_CHECK(x.scalar_type() == st, \
|
||||
": Inconsistency of torch::stable::Tensor type:", m)
|
||||
#define CHECK_TH_CUDA(x, m) \
|
||||
TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor")
|
||||
STD_TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x, m) \
|
||||
TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous")
|
||||
STD_TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous")
|
||||
#define CHECK_INPUT(x, st, m) \
|
||||
CHECK_TH_CUDA(x, m); \
|
||||
CHECK_CONTIGUOUS(x, m); \
|
||||
CHECK_TYPE(x, st, m)
|
||||
|
||||
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
|
||||
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
|
||||
constexpr auto FLOAT4_E2M1X2 = torch::headeronly::ScalarType::Byte;
|
||||
constexpr auto SF_DTYPE = torch::headeronly::ScalarType::Float8_e4m3fn;
|
||||
|
||||
struct sm120_fp4_config_M256 {
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
@@ -109,12 +109,13 @@ struct Fp4GemmSm120 {
|
||||
};
|
||||
|
||||
template <typename Gemm>
|
||||
typename Gemm::Arguments args_from_options(at::Tensor& D, at::Tensor const& A,
|
||||
at::Tensor const& B,
|
||||
at::Tensor const& A_sf,
|
||||
at::Tensor const& B_sf,
|
||||
torch::Tensor const& alpha, int M,
|
||||
int N, int K) {
|
||||
typename Gemm::Arguments args_from_options(torch::stable::Tensor& D,
|
||||
torch::stable::Tensor const& A,
|
||||
torch::stable::Tensor const& B,
|
||||
torch::stable::Tensor const& A_sf,
|
||||
torch::stable::Tensor const& B_sf,
|
||||
torch::stable::Tensor const& alpha,
|
||||
int M, int N, int K) {
|
||||
using ElementA = typename Gemm::ElementA;
|
||||
using ElementB = typename Gemm::ElementB;
|
||||
using ElementD = typename Gemm::ElementD;
|
||||
@@ -158,18 +159,19 @@ typename Gemm::Arguments args_from_options(at::Tensor& D, at::Tensor const& A,
|
||||
}
|
||||
|
||||
template <typename Gemm>
|
||||
void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
|
||||
at::Tensor const& A_sf, at::Tensor const& B_sf,
|
||||
torch::Tensor const& alpha, int M, int N, int K,
|
||||
void runGemm(torch::stable::Tensor& D, torch::stable::Tensor const& A,
|
||||
torch::stable::Tensor const& B, torch::stable::Tensor const& A_sf,
|
||||
torch::stable::Tensor const& B_sf,
|
||||
torch::stable::Tensor const& alpha, int M, int N, int K,
|
||||
cudaStream_t stream) {
|
||||
Gemm gemm;
|
||||
|
||||
auto arguments = args_from_options<Gemm>(D, A, B, A_sf, B_sf, alpha, M, N, K);
|
||||
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
auto const workspace_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
auto workspace =
|
||||
torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
|
||||
std::nullopt, A.device());
|
||||
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
@@ -178,12 +180,13 @@ void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
|
||||
CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream));
|
||||
}
|
||||
|
||||
void cutlass_fp4_bf16_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
|
||||
torch::Tensor const& B,
|
||||
torch::Tensor const& A_sf,
|
||||
torch::Tensor const& B_sf,
|
||||
torch::Tensor const& alpha, int m, int n,
|
||||
int k, cudaStream_t stream) {
|
||||
void cutlass_fp4_bf16_gemm_dispatch(torch::stable::Tensor& D,
|
||||
torch::stable::Tensor const& A,
|
||||
torch::stable::Tensor const& B,
|
||||
torch::stable::Tensor const& A_sf,
|
||||
torch::stable::Tensor const& B_sf,
|
||||
torch::stable::Tensor const& alpha, int m,
|
||||
int n, int k, cudaStream_t stream) {
|
||||
uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
|
||||
if (mp2 <= 256) {
|
||||
runGemm<Fp4GemmSm120<sm120_fp4_config_M256, cutlass::bfloat16_t>::Gemm>(
|
||||
@@ -194,12 +197,13 @@ void cutlass_fp4_bf16_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
|
||||
}
|
||||
}
|
||||
|
||||
void cutlass_fp4_f16_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
|
||||
torch::Tensor const& B,
|
||||
torch::Tensor const& A_sf,
|
||||
torch::Tensor const& B_sf,
|
||||
torch::Tensor const& alpha, int m, int n,
|
||||
int k, cudaStream_t stream) {
|
||||
void cutlass_fp4_f16_gemm_dispatch(torch::stable::Tensor& D,
|
||||
torch::stable::Tensor const& A,
|
||||
torch::stable::Tensor const& B,
|
||||
torch::stable::Tensor const& A_sf,
|
||||
torch::stable::Tensor const& B_sf,
|
||||
torch::stable::Tensor const& alpha, int m,
|
||||
int n, int k, cudaStream_t stream) {
|
||||
uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
|
||||
if (mp2 <= 256) {
|
||||
runGemm<Fp4GemmSm120<sm120_fp4_config_M256, cutlass::half_t>::Gemm>(
|
||||
@@ -210,11 +214,12 @@ void cutlass_fp4_f16_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
|
||||
}
|
||||
}
|
||||
|
||||
void cutlass_scaled_fp4_mm_sm120a(torch::Tensor& D, torch::Tensor const& A,
|
||||
torch::Tensor const& B,
|
||||
torch::Tensor const& A_sf,
|
||||
torch::Tensor const& B_sf,
|
||||
torch::Tensor const& alpha) {
|
||||
void cutlass_scaled_fp4_mm_sm120a(torch::stable::Tensor& D,
|
||||
torch::stable::Tensor const& A,
|
||||
torch::stable::Tensor const& B,
|
||||
torch::stable::Tensor const& A_sf,
|
||||
torch::stable::Tensor const& B_sf,
|
||||
torch::stable::Tensor const& alpha) {
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
CHECK_INPUT(A, FLOAT4_E2M1X2, "a");
|
||||
CHECK_INPUT(B, FLOAT4_E2M1X2, "b");
|
||||
@@ -222,24 +227,25 @@ void cutlass_scaled_fp4_mm_sm120a(torch::Tensor& D, torch::Tensor const& A,
|
||||
CHECK_INPUT(A_sf, SF_DTYPE, "scale_a");
|
||||
CHECK_INPUT(B_sf, SF_DTYPE, "scale_b");
|
||||
|
||||
CHECK_INPUT(alpha, at::ScalarType::Float, "alpha");
|
||||
CHECK_INPUT(alpha, torch::headeronly::ScalarType::Float, "alpha");
|
||||
|
||||
TORCH_CHECK(A.dim() == 2, "a must be a matrix");
|
||||
TORCH_CHECK(B.dim() == 2, "b must be a matrix");
|
||||
TORCH_CHECK(A.sizes()[1] == B.sizes()[1],
|
||||
"a and b shapes cannot be multiplied (", A.sizes()[0], "x",
|
||||
A.sizes()[1], " and ", B.sizes()[0], "x", B.sizes()[1], ")");
|
||||
STD_TORCH_CHECK(A.dim() == 2, "a must be a matrix");
|
||||
STD_TORCH_CHECK(B.dim() == 2, "b must be a matrix");
|
||||
STD_TORCH_CHECK(A.size(1) == B.size(1),
|
||||
"a and b shapes cannot be multiplied (", A.size(0), "x",
|
||||
A.size(1), " and ", B.size(0), "x", B.size(1), ")");
|
||||
|
||||
auto const m = A.sizes()[0];
|
||||
auto const n = B.sizes()[0];
|
||||
auto const k = A.sizes()[1] * 2;
|
||||
auto const m = A.size(0);
|
||||
auto const n = B.size(0);
|
||||
auto const k = A.size(1) * 2;
|
||||
|
||||
constexpr int alignment = 32;
|
||||
TORCH_CHECK(k % alignment == 0, "Expected k to be divisible by ", alignment,
|
||||
", but got a shape: (", A.sizes()[0], "x", A.sizes()[1],
|
||||
"), k: ", k, ".");
|
||||
TORCH_CHECK(n % alignment == 0, "Expected n to be divisible by ", alignment,
|
||||
", but got b shape: (", B.sizes()[0], "x", B.sizes()[1], ").");
|
||||
STD_TORCH_CHECK(k % alignment == 0, "Expected k to be divisible by ",
|
||||
alignment, ", but got a shape: (", A.size(0), "x", A.size(1),
|
||||
"), k: ", k, ".");
|
||||
STD_TORCH_CHECK(n % alignment == 0, "Expected n to be divisible by ",
|
||||
alignment, ", but got b shape: (", B.size(0), "x", B.size(1),
|
||||
").");
|
||||
|
||||
auto round_up = [](int x, int y) { return (x + y - 1) / y * y; };
|
||||
int rounded_m = round_up(m, 128);
|
||||
@@ -248,38 +254,39 @@ void cutlass_scaled_fp4_mm_sm120a(torch::Tensor& D, torch::Tensor const& A,
|
||||
// integer.
|
||||
int rounded_k = round_up(k / 16, 4);
|
||||
|
||||
TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix");
|
||||
TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix");
|
||||
TORCH_CHECK(A_sf.sizes()[1] == B_sf.sizes()[1],
|
||||
"scale_a and scale_b shapes cannot be multiplied (",
|
||||
A_sf.sizes()[0], "x", A_sf.sizes()[1], " and ", B_sf.sizes()[0],
|
||||
"x", B_sf.sizes()[1], ")");
|
||||
TORCH_CHECK(A_sf.sizes()[0] == rounded_m && A_sf.sizes()[1] == rounded_k,
|
||||
"scale_a must be padded and swizzled to a shape (", rounded_m,
|
||||
"x", rounded_k, "), but got a shape (", A_sf.sizes()[0], "x",
|
||||
A_sf.sizes()[1], ")");
|
||||
TORCH_CHECK(B_sf.sizes()[0] == rounded_n && B_sf.sizes()[1] == rounded_k,
|
||||
"scale_b must be padded and swizzled to a shape (", rounded_n,
|
||||
"x", rounded_k, "), but got a shape (", B_sf.sizes()[0], "x",
|
||||
B_sf.sizes()[1], ")");
|
||||
STD_TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix");
|
||||
STD_TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix");
|
||||
STD_TORCH_CHECK(A_sf.size(1) == B_sf.size(1),
|
||||
"scale_a and scale_b shapes cannot be multiplied (",
|
||||
A_sf.size(0), "x", A_sf.size(1), " and ", B_sf.size(0), "x",
|
||||
B_sf.size(1), ")");
|
||||
STD_TORCH_CHECK(A_sf.size(0) == rounded_m && A_sf.size(1) == rounded_k,
|
||||
"scale_a must be padded and swizzled to a shape (", rounded_m,
|
||||
"x", rounded_k, "), but got a shape (", A_sf.size(0), "x",
|
||||
A_sf.size(1), ")");
|
||||
STD_TORCH_CHECK(B_sf.size(0) == rounded_n && B_sf.size(1) == rounded_k,
|
||||
"scale_b must be padded and swizzled to a shape (", rounded_n,
|
||||
"x", rounded_k, "), but got a shape (", B_sf.size(0), "x",
|
||||
B_sf.size(1), ")");
|
||||
|
||||
auto out_dtype = D.dtype();
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device());
|
||||
auto out_dtype = D.scalar_type();
|
||||
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||
A.get_device_index());
|
||||
const cudaStream_t stream = get_current_cuda_stream(A.get_device_index());
|
||||
|
||||
if (out_dtype == at::ScalarType::BFloat16) {
|
||||
if (out_dtype == torch::headeronly::ScalarType::BFloat16) {
|
||||
return cutlass_fp4_bf16_gemm_dispatch(D, A, B, A_sf, B_sf, alpha, m, n, k,
|
||||
stream);
|
||||
} else if (out_dtype == at::ScalarType::Half) {
|
||||
} else if (out_dtype == torch::headeronly::ScalarType::Half) {
|
||||
return cutlass_fp4_f16_gemm_dispatch(D, A, B, A_sf, B_sf, alpha, m, n, k,
|
||||
stream);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm sm120 (",
|
||||
out_dtype, ")");
|
||||
STD_TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm sm120 (",
|
||||
out_dtype, ")");
|
||||
}
|
||||
#else
|
||||
TORCH_CHECK(false,
|
||||
"Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
|
||||
"a CUTLASS 3.8 source directory to enable support.");
|
||||
STD_TORCH_CHECK(false,
|
||||
"Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
|
||||
"a CUTLASS 3.8 source directory to enable support.");
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
}
|
||||
}
|
||||
@@ -20,7 +20,7 @@
#include <cuda_fp8.h>
#include <utility>

#include "../../cuda_vec_utils.cuh"
#include "cuda_vec_utils.cuh"

#if defined(NVFP4_ENABLE_ELTS16) && defined(CUDA_VERSION) && \
CUDA_VERSION >= 12090
@@ -2,9 +2,10 @@
|
||||
|
||||
// clang-format will break include orders
|
||||
// clang-format off
|
||||
#include <torch/all.h>
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include <torch/csrc/stable/ops.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
@@ -25,14 +26,14 @@
|
||||
namespace vllm::c3x {
|
||||
|
||||
static inline cute::Shape<int, int, int, int> get_problem_shape(
|
||||
torch::Tensor const& a, torch::Tensor const& b) {
|
||||
torch::stable::Tensor const& a, torch::stable::Tensor const& b) {
|
||||
int32_t m = a.size(0), n = b.size(1), k = a.size(1);
|
||||
return {m, n, k, 1};
|
||||
}
|
||||
|
||||
template <typename GemmKernel>
|
||||
void cutlass_gemm_caller(
|
||||
torch::Device device, cute::Shape<int, int, int, int> prob_shape,
|
||||
torch::stable::Device device, cute::Shape<int, int, int, int> prob_shape,
|
||||
typename GemmKernel::MainloopArguments mainloop_args,
|
||||
typename GemmKernel::EpilogueArguments epilogue_args,
|
||||
typename GemmKernel::TileSchedulerArguments scheduler = {}) {
|
||||
@@ -50,19 +51,20 @@ void cutlass_gemm_caller(
|
||||
CUTLASS_CHECK(gemm_op.can_implement(args));
|
||||
|
||||
size_t workspace_size = gemm_op.get_workspace_size(args);
|
||||
auto const workspace_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(device);
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
auto workspace =
|
||||
torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
|
||||
std::nullopt, device);
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream(device.index());
|
||||
auto stream = get_current_cuda_stream(device.index());
|
||||
|
||||
cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
template <typename Gemm, typename... EpilogueArgs>
|
||||
void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
void cutlass_gemm_caller(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_params) {
|
||||
using ElementAB = typename Gemm::ElementAB;
|
||||
using ElementC = typename Gemm::ElementC;
|
||||
@@ -4,13 +4,12 @@

namespace vllm {

void cutlass_scaled_mm_azp_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
std::optional<torch::Tensor> const& azp,
std::optional<torch::Tensor> const& bias) {
void cutlass_scaled_mm_azp_sm90_int8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias) {
if (azp) {
return cutlass_scaled_mm_sm90_int8_epilogue<
c3x::ScaledEpilogueBiasAzpToken>(out, a, b, a_scales, b_scales, azp_adj,
@@ -0,0 +1,22 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_blockwise_sm100_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"

namespace vllm {

void cutlass_scaled_mm_blockwise_sm100_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales) {
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::bfloat16_t>(
out, a, b, a_scales, b_scales);

} else {
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::half_t>(
out, a, b, a_scales, b_scales);
}
}

} // namespace vllm
@@ -1,5 +1,7 @@
#pragma once

#include <torch/headeronly/util/shim_utils.h>

#include "cuda_utils.h"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
@@ -130,10 +132,10 @@ struct cutlass_3x_gemm_fp8_blockwise {
|
||||
};
|
||||
|
||||
template <typename Gemm>
|
||||
void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales) {
|
||||
static constexpr bool swap_ab = Gemm::swap_ab;
|
||||
using GemmKernel = typename Gemm::GemmKernel;
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
@@ -200,11 +202,11 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
|
||||
}
|
||||
|
||||
template <typename OutType>
|
||||
void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales) {
|
||||
int32_t m = a.size(0), n = b.size(1), k = a.size(1), sms;
|
||||
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device());
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
#include "scaled_mm_kernels.hpp"
|
||||
#include "scaled_mm_blockwise_sm120_fp8_dispatch.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
void cutlass_scaled_mm_blockwise_sm120_fp8(
|
||||
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales) {
|
||||
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
|
||||
cutlass_gemm_blockwise_sm120_fp8_dispatch<cutlass::bfloat16_t>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
|
||||
} else {
|
||||
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
|
||||
cutlass_gemm_blockwise_sm120_fp8_dispatch<cutlass::half_t>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@@ -1,5 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/headeronly/util/shim_utils.h>
|
||||
|
||||
#include "cuda_utils.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
@@ -138,10 +140,10 @@ struct sm120_blockwise_fp8_config_M64 {
|
||||
};
|
||||
|
||||
template <typename Gemm>
|
||||
void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales) {
|
||||
using GemmKernel = typename Gemm::GemmKernel;
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
@@ -196,11 +198,11 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
|
||||
}
|
||||
|
||||
template <typename OutType>
|
||||
void cutlass_gemm_blockwise_sm120_fp8_dispatch(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
void cutlass_gemm_blockwise_sm120_fp8_dispatch(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales) {
|
||||
int M = a.size(0);
|
||||
if (M <= 256) {
|
||||
using Gemm = typename sm120_blockwise_fp8_config_M64<OutType>::Gemm;
|
||||
@@ -0,0 +1,23 @@
|
||||
|
||||
#include "scaled_mm_kernels.hpp"
|
||||
#include "scaled_mm_blockwise_sm90_fp8_dispatch.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
void cutlass_scaled_mm_blockwise_sm90_fp8(
|
||||
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales) {
|
||||
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
|
||||
cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::bfloat16_t>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
|
||||
} else {
|
||||
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
|
||||
cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::half_t>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@@ -1,5 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/headeronly/util/shim_utils.h>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
@@ -101,10 +103,10 @@ struct cutlass_3x_gemm_fp8_blockwise {
|
||||
};
|
||||
|
||||
template <typename Gemm>
|
||||
void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales) {
|
||||
using GemmKernel = typename Gemm::GemmKernel;
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
@@ -120,7 +122,7 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
|
||||
|
||||
int32_t m = a.size(0), n = b.size(1), k = a.size(1);
|
||||
|
||||
TORCH_CHECK(m % 4 == 0, "m must be divisible by 4");
|
||||
STD_TORCH_CHECK(m % 4 == 0, "m must be divisible by 4");
|
||||
|
||||
StrideA a_stride;
|
||||
StrideB b_stride;
|
||||
@@ -161,11 +163,11 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
|
||||
}
|
||||
|
||||
template <typename OutType>
|
||||
void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales) {
|
||||
// TODO: better heuristics
|
||||
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
|
||||
OutType, 1, 128, 128, Shape<_128, _128, _128>,
|
||||
@@ -1,52 +1,57 @@
|
||||
#include <torch/all.h>
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include <torch/headeronly/core/ScalarType.h>
|
||||
#include "cuda_utils.h"
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
|
||||
template <typename Fp8Func, typename Int8Func, typename BlockwiseFunc>
|
||||
void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias,
|
||||
void dispatch_scaled_mm(torch::stable::Tensor& c,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
std::optional<torch::stable::Tensor> const& bias,
|
||||
Fp8Func fp8_func, Int8Func int8_func,
|
||||
BlockwiseFunc blockwise_func) {
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
STD_TORCH_CHECK(a_scales.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float);
|
||||
STD_TORCH_CHECK(b_scales.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float);
|
||||
|
||||
int M = a.size(0), N = b.size(1), K = a.size(1);
|
||||
|
||||
if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
|
||||
(b_scales.numel() == 1 || b_scales.numel() == b.size(1))) {
|
||||
// Standard per-tensor/per-token/per-channel scaling
|
||||
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
if (a.dtype() == torch::kFloat8_e4m3fn) {
|
||||
STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
if (a.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn) {
|
||||
fp8_func(c, a, b, a_scales, b_scales, bias);
|
||||
} else {
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||
if constexpr (!std::is_same_v<Int8Func, std::nullptr_t>) {
|
||||
int8_func(c, a, b, a_scales, b_scales, bias);
|
||||
} else {
|
||||
int32_t version_num = get_sm_version_num();
|
||||
TORCH_CHECK(
|
||||
STD_TORCH_CHECK(
|
||||
false, "Int8 not supported on SM", version_num,
|
||||
". Use FP8 quantization instead, or run on older arch (SM < 100).");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(a_scales.dim() == 2, "a scale must be 2d tensor.");
|
||||
TORCH_CHECK(b_scales.dim() == 2, "b scale must be 2d tensor.");
|
||||
STD_TORCH_CHECK(a_scales.dim() == 2, "a scale must be 2d tensor.");
|
||||
STD_TORCH_CHECK(b_scales.dim() == 2, "b scale must be 2d tensor.");
|
||||
int32_t version_num = get_sm_version_num();
|
||||
if (version_num >= 90) {
|
||||
TORCH_CHECK(
|
||||
STD_TORCH_CHECK(
|
||||
a.size(0) == a_scales.size(0) &&
|
||||
cuda_utils::ceil_div(a.size(1), int64_t(128)) == a_scales.size(1),
|
||||
"a_scale_group_shape must be [1, 128].");
|
||||
TORCH_CHECK(
|
||||
STD_TORCH_CHECK(
|
||||
cuda_utils::ceil_div(b.size(0), int64_t(128)) == b_scales.size(0) &&
|
||||
cuda_utils::ceil_div(b.size(1), int64_t(128)) == b_scales.size(1),
|
||||
"b_scale_group_shape must be [128, 128].");
|
||||
}
|
||||
|
||||
TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
|
||||
STD_TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
|
||||
blockwise_func(c, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
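The hunk above is representative of the check migration applied throughout this compare: aten-style TORCH_CHECK / dtype() assertions become STD_TORCH_CHECK / scalar_type() assertions on torch::stable::Tensor. A minimal sketch of the resulting scale validation, using only calls that appear verbatim in the hunk:

#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/core/ScalarType.h>
#include <torch/headeronly/util/shim_utils.h>

// Sketch: the fp32 / contiguity checks that dispatch_scaled_mm now performs
// through the stable ABI instead of the aten API.
inline void check_scales(torch::stable::Tensor const& a_scales,
                         torch::stable::Tensor const& b_scales) {
  STD_TORCH_CHECK(a_scales.scalar_type() ==
                  torch::headeronly::ScalarType::Float);
  STD_TORCH_CHECK(b_scales.scalar_type() ==
                  torch::headeronly::ScalarType::Float);
  STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
}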
@@ -0,0 +1,52 @@
#pragma once

#include <torch/csrc/stable/tensor.h>

namespace vllm {

void cutlass_scaled_mm_sm90_fp8(
    torch::stable::Tensor& out, torch::stable::Tensor const& a,
    torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
    torch::stable::Tensor const& b_scales,
    std::optional<torch::stable::Tensor> const& bias);

void cutlass_scaled_mm_sm90_int8(
    torch::stable::Tensor& out, torch::stable::Tensor const& a,
    torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
    torch::stable::Tensor const& b_scales,
    std::optional<torch::stable::Tensor> const& bias);

void cutlass_scaled_mm_azp_sm90_int8(
    torch::stable::Tensor& out, torch::stable::Tensor const& a,
    torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
    torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
    std::optional<torch::stable::Tensor> const& azp,
    std::optional<torch::stable::Tensor> const& bias);

void cutlass_scaled_mm_blockwise_sm90_fp8(
    torch::stable::Tensor& out, torch::stable::Tensor const& a,
    torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
    torch::stable::Tensor const& b_scales);

void cutlass_scaled_mm_sm100_fp8(
    torch::stable::Tensor& out, torch::stable::Tensor const& a,
    torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
    torch::stable::Tensor const& b_scales,
    std::optional<torch::stable::Tensor> const& bias);

void cutlass_scaled_mm_sm120_fp8(
    torch::stable::Tensor& out, torch::stable::Tensor const& a,
    torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
    torch::stable::Tensor const& b_scales,
    std::optional<torch::stable::Tensor> const& bias);

void cutlass_scaled_mm_blockwise_sm100_fp8(
    torch::stable::Tensor& out, torch::stable::Tensor const& a,
    torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
    torch::stable::Tensor const& b_scales);

void cutlass_scaled_mm_blockwise_sm120_fp8(
    torch::stable::Tensor& out, torch::stable::Tensor const& a,
    torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
    torch::stable::Tensor const& b_scales);
} // namespace vllm
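The new scaled_mm_kernels.hpp above only declares the stable-ABI entry points. A hedged usage sketch (the call site below is assumed, not part of the diff) of invoking one of them without a bias:

#include <optional>
#include "scaled_mm_kernels.hpp"

// Assumed call site: run the SM90 FP8 kernel with no bias tensor.
void run_sm90_fp8_no_bias(torch::stable::Tensor& out,
                          torch::stable::Tensor const& a,
                          torch::stable::Tensor const& b,
                          torch::stable::Tensor const& a_scales,
                          torch::stable::Tensor const& b_scales) {
  vllm::cutlass_scaled_mm_sm90_fp8(out, a, b, a_scales, b_scales,
                                   std::nullopt);
}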
@@ -0,0 +1,24 @@
|
||||
#include "scaled_mm_kernels.hpp"
|
||||
#include "scaled_mm_sm100_fp8_dispatch.cuh"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
void cutlass_scaled_mm_sm100_fp8(
|
||||
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
std::optional<torch::stable::Tensor> const& bias) {
|
||||
STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
if (bias) {
|
||||
STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
|
||||
"currently bias dtype must match output dtype ",
|
||||
out.scalar_type());
|
||||
return cutlass_scaled_mm_sm100_fp8_epilogue<true>(out, a, b, a_scales,
|
||||
b_scales, *bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm100_fp8_epilogue<false>(out, a, b, a_scales,
|
||||
b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@@ -1,5 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/headeronly/util/shim_utils.h>
|
||||
|
||||
#include "scaled_mm.cuh"
|
||||
#include "cutlass_gemm_caller.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
@@ -192,8 +194,9 @@ struct sm100_fp8_config_M16_swap_ab {
|
||||
};
|
||||
|
||||
template <typename Gemm, typename... EpilogueArgs>
|
||||
void cutlass_gemm_caller_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
void cutlass_gemm_caller_sm100_fp8(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_params) {
|
||||
static constexpr bool swap_ab = Gemm::swap_ab;
|
||||
using ElementAB = typename Gemm::ElementAB;
|
||||
@@ -237,15 +240,15 @@ void cutlass_gemm_caller_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
|
||||
|
||||
template <typename InType, typename OutType, bool EnableBias,
|
||||
typename... EpilogueArgs>
|
||||
inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
EpilogueArgs&&... args) {
|
||||
inline void cutlass_gemm_sm100_fp8_dispatch(
|
||||
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
||||
STD_TORCH_CHECK(a.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
STD_TORCH_CHECK(b.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
|
||||
using Cutlass3xGemmDefault =
|
||||
typename sm100_fp8_config_default<InType, OutType,
|
||||
@@ -292,22 +295,24 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
|
||||
}
|
||||
|
||||
template <bool EnableBias, typename... EpilogueArgs>
|
||||
void cutlass_scaled_mm_sm100_fp8_epilogue(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
void cutlass_scaled_mm_sm100_fp8_epilogue(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
EpilogueArgs&&... epilogue_args) {
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
||||
STD_TORCH_CHECK(a.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
STD_TORCH_CHECK(b.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
|
||||
return cutlass_gemm_sm100_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::bfloat16_t, EnableBias>(
|
||||
out, a, b, a_scales, b_scales,
|
||||
std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
|
||||
return cutlass_gemm_sm100_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::half_t, EnableBias>(
|
||||
out, a, b, a_scales, b_scales,
|
||||
@@ -0,0 +1,25 @@
|
||||
#include "scaled_mm_kernels.hpp"
|
||||
#include "scaled_mm_sm120_fp8_dispatch.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
void cutlass_scaled_mm_sm120_fp8(
|
||||
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
std::optional<torch::stable::Tensor> const& bias) {
|
||||
STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
if (bias) {
|
||||
STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
|
||||
"currently bias dtype must match output dtype ",
|
||||
out.scalar_type());
|
||||
return cutlass_scaled_mm_sm120_fp8_epilogue<c3x::ScaledEpilogueBias>(
|
||||
out, a, b, a_scales, b_scales, *bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm120_fp8_epilogue<c3x::ScaledEpilogue>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@@ -1,5 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/headeronly/util/shim_utils.h>
|
||||
|
||||
#include "scaled_mm.cuh"
|
||||
#include "cutlass_gemm_caller.cuh"
|
||||
|
||||
@@ -138,13 +140,15 @@ struct sm120_fp8_config_M16 {
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
inline void cutlass_gemm_sm120_fp8_dispatch(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
inline void cutlass_gemm_sm120_fp8_dispatch(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
||||
STD_TORCH_CHECK(a.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
STD_TORCH_CHECK(b.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
|
||||
int M = a.size(0);
|
||||
|
||||
@@ -177,19 +181,21 @@ inline void cutlass_gemm_sm120_fp8_dispatch(torch::Tensor& out,
|
||||
|
||||
template <template <typename, typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_scaled_mm_sm120_fp8_epilogue(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
void cutlass_scaled_mm_sm120_fp8_epilogue(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_args) {
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
||||
STD_TORCH_CHECK(a.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
STD_TORCH_CHECK(b.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
|
||||
return cutlass_gemm_sm120_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::bfloat16_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
|
||||
return cutlass_gemm_sm120_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::half_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
@@ -0,0 +1,24 @@
|
||||
#include "scaled_mm_kernels.hpp"
|
||||
#include "scaled_mm_sm90_fp8_dispatch.cuh"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
void cutlass_scaled_mm_sm90_fp8(
|
||||
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
std::optional<torch::stable::Tensor> const& bias) {
|
||||
STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
if (bias) {
|
||||
STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
|
||||
"currently bias dtype must match output dtype ",
|
||||
out.scalar_type());
|
||||
return cutlass_scaled_mm_sm90_fp8_epilogue<true>(out, a, b, a_scales,
|
||||
b_scales, *bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm90_fp8_epilogue<false>(out, a, b, a_scales,
|
||||
b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@@ -1,5 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/headeronly/util/shim_utils.h>
|
||||
|
||||
#include "scaled_mm.cuh"
|
||||
#include "cutlass_gemm_caller.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
@@ -235,8 +237,9 @@ struct sm90_fp8_config_M16_N8192 {
|
||||
};
|
||||
|
||||
template <typename Gemm, typename... EpilogueArgs>
|
||||
void cutlass_gemm_caller_sm90_fp8(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
void cutlass_gemm_caller_sm90_fp8(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_params) {
|
||||
static constexpr bool swap_ab = Gemm::swap_ab;
|
||||
using ElementAB = typename Gemm::ElementAB;
|
||||
@@ -280,15 +283,15 @@ void cutlass_gemm_caller_sm90_fp8(torch::Tensor& out, torch::Tensor const& a,
|
||||
|
||||
template <typename InType, typename OutType, bool EnableBias,
|
||||
typename... EpilogueArgs>
|
||||
inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
EpilogueArgs&&... args) {
|
||||
inline void cutlass_gemm_sm90_fp8_dispatch(
|
||||
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
||||
STD_TORCH_CHECK(a.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
STD_TORCH_CHECK(b.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
|
||||
using Cutlass3xGemmDefault =
|
||||
typename sm90_fp8_config_default<InType, OutType,
|
||||
@@ -347,22 +350,24 @@ inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out,
|
||||
}
|
||||
|
||||
template <bool EnableBias, typename... EpilogueArgs>
|
||||
void cutlass_scaled_mm_sm90_fp8_epilogue(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
void cutlass_scaled_mm_sm90_fp8_epilogue(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
EpilogueArgs&&... epilogue_args) {
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
||||
STD_TORCH_CHECK(a.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
STD_TORCH_CHECK(b.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
|
||||
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::bfloat16_t, EnableBias>(
|
||||
out, a, b, a_scales, b_scales,
|
||||
std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
|
||||
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::half_t, EnableBias>(
|
||||
out, a, b, a_scales, b_scales,
|
||||
@@ -0,0 +1,25 @@
|
||||
#include "scaled_mm_kernels.hpp"
|
||||
#include "scaled_mm_sm90_int8_dispatch.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
void cutlass_scaled_mm_sm90_int8(
|
||||
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
std::optional<torch::stable::Tensor> const& bias) {
|
||||
STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
if (bias) {
|
||||
STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
|
||||
"currently bias dtype must match output dtype ",
|
||||
out.scalar_type());
|
||||
return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogueBias>(
|
||||
out, a, b, a_scales, b_scales, *bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogue>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@@ -1,5 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/headeronly/util/shim_utils.h>
|
||||
|
||||
#include "scaled_mm.cuh"
|
||||
#include "cutlass_gemm_caller.cuh"
|
||||
|
||||
@@ -87,13 +89,13 @@ struct sm90_int8_config_M32_NSmall {
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
inline void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
inline void cutlass_gemm_sm90_int8_dispatch(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||
|
||||
using Cutlass3xGemmDefault =
|
||||
typename sm90_int8_config_default<InType, OutType,
|
||||
@@ -142,19 +144,19 @@ inline void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out,
|
||||
|
||||
template <template <typename, typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_scaled_mm_sm90_int8_epilogue(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
void cutlass_scaled_mm_sm90_int8_epilogue(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_args) {
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
|
||||
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
|
||||
Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
|
||||
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
@@ -1,10 +1,10 @@
|
||||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
#include <torch/all.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include <torch/headeronly/core/ScalarType.h>
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
|
||||
#include "core/scalar_type.hpp"
|
||||
#include "cutlass/bfloat16.h"
|
||||
#include "cutlass/float8.h"
|
||||
|
||||
@@ -31,7 +31,7 @@ __global__ void get_group_gemm_starts(
|
||||
}
|
||||
|
||||
#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE) \
|
||||
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
|
||||
else if (out_tensors.scalar_type() == TENSOR_C_TYPE) { \
|
||||
get_group_gemm_starts<cutlass::float_e4m3_t, C_TYPE, float> \
|
||||
<<<1, num_experts, 0, stream>>>( \
|
||||
static_cast<int64_t*>(expert_offsets.data_ptr()), \
|
||||
@@ -51,32 +51,39 @@ __global__ void get_group_gemm_starts(
|
||||
namespace {
|
||||
|
||||
void run_get_group_gemm_starts(
|
||||
torch::Tensor const& expert_offsets, torch::Tensor& a_ptrs,
|
||||
torch::Tensor& b_ptrs, torch::Tensor& out_ptrs,
|
||||
torch::Tensor& a_scales_ptrs, torch::Tensor& b_scales_ptrs,
|
||||
torch::Tensor const& a_tensors, torch::Tensor const& b_tensors,
|
||||
torch::Tensor& out_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
torch::stable::Tensor const& expert_offsets, torch::stable::Tensor& a_ptrs,
|
||||
torch::stable::Tensor& b_ptrs, torch::stable::Tensor& out_ptrs,
|
||||
torch::stable::Tensor& a_scales_ptrs, torch::stable::Tensor& b_scales_ptrs,
|
||||
torch::stable::Tensor const& a_tensors,
|
||||
torch::stable::Tensor const& b_tensors, torch::stable::Tensor& out_tensors,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales) {
|
||||
STD_TORCH_CHECK(a_tensors.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
STD_TORCH_CHECK(b_tensors.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
STD_TORCH_CHECK(a_scales.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float);
|
||||
STD_TORCH_CHECK(b_scales.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float);
|
||||
// expect int64_t to avoid overflow during offset calculations
|
||||
TORCH_CHECK(expert_offsets.dtype() == torch::kInt64);
|
||||
STD_TORCH_CHECK(expert_offsets.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Long);
|
||||
|
||||
int num_experts = static_cast<int>(expert_offsets.size(0));
|
||||
bool per_act_token = a_scales.numel() != 1;
|
||||
bool per_out_ch = b_scales.numel() != num_experts;
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
|
||||
auto stream = get_current_cuda_stream(a_tensors.get_device_index());
|
||||
|
||||
if (false) {
|
||||
}
|
||||
__CALL_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t)
|
||||
__CALL_GET_STARTS_KERNEL(torch::kFloat16, half)
|
||||
__CALL_GET_STARTS_KERNEL(torch::headeronly::ScalarType::BFloat16,
|
||||
cutlass::bfloat16_t)
|
||||
__CALL_GET_STARTS_KERNEL(torch::headeronly::ScalarType::Half, half)
|
||||
else {
|
||||
TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
|
||||
STD_TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace
|
||||
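These hunks also change how the CUDA stream is obtained: at::cuda::getCurrentCUDAStream(tensor.device().index()) becomes get_current_cuda_stream(tensor.get_device_index()), presumably provided by the newly included libtorch_stable/torch_utils.h. A minimal sketch, assuming the helper returns something convertible to cudaStream_t (as its use in the kernel launches here implies):

#include <cuda_runtime.h>
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"  // assumed to declare get_current_cuda_stream

// Sketch: look up the current stream for the device that owns a tensor.
inline cudaStream_t stream_for(torch::stable::Tensor const& t) {
  return get_current_cuda_stream(t.get_device_index());
}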
@@ -6,6 +6,7 @@
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
|
||||
#include <torch/csrc/stable/ops.h>
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
#include "get_group_starts.cuh"
|
||||
@@ -84,13 +85,17 @@ struct cutlass_3x_group_gemm {
|
||||
};
|
||||
|
||||
template <typename Gemm>
|
||||
void cutlass_group_gemm_caller(
|
||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch) {
|
||||
void cutlass_group_gemm_caller(torch::stable::Tensor& out_tensors,
|
||||
torch::stable::Tensor const& a_tensors,
|
||||
torch::stable::Tensor const& b_tensors,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
torch::stable::Tensor const& expert_offsets,
|
||||
torch::stable::Tensor const& problem_sizes,
|
||||
torch::stable::Tensor const& a_strides,
|
||||
torch::stable::Tensor const& b_strides,
|
||||
torch::stable::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch) {
|
||||
static constexpr bool swap_ab = Gemm::swap_ab;
|
||||
|
||||
using ElementAB = typename Gemm::ElementAB;
|
||||
@@ -98,16 +103,20 @@ void cutlass_group_gemm_caller(
|
||||
|
||||
int num_experts = static_cast<int>(expert_offsets.size(0));
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
|
||||
auto stream = get_current_cuda_stream(a_tensors.get_device_index());
|
||||
|
||||
auto options_int =
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a_tensors.device());
|
||||
auto device = a_tensors.device();
|
||||
|
||||
torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::stable::Tensor a_ptrs = torch::stable::empty(
|
||||
{num_experts}, torch::headeronly::ScalarType::Long, std::nullopt, device);
|
||||
torch::stable::Tensor b_ptrs = torch::stable::empty(
|
||||
{num_experts}, torch::headeronly::ScalarType::Long, std::nullopt, device);
|
||||
torch::stable::Tensor out_ptrs = torch::stable::empty(
|
||||
{num_experts}, torch::headeronly::ScalarType::Long, std::nullopt, device);
|
||||
torch::stable::Tensor a_scales_ptrs = torch::stable::empty(
|
||||
{num_experts}, torch::headeronly::ScalarType::Long, std::nullopt, device);
|
||||
torch::stable::Tensor b_scales_ptrs = torch::stable::empty(
|
||||
{num_experts}, torch::headeronly::ScalarType::Long, std::nullopt, device);
|
||||
|
||||
run_get_group_gemm_starts(expert_offsets, a_ptrs, b_ptrs, out_ptrs,
|
||||
a_scales_ptrs, b_scales_ptrs, a_tensors, b_tensors,
|
||||
@@ -156,7 +165,7 @@ void cutlass_group_gemm_caller(
|
||||
static_cast<ElementD**>(out_ptrs.data_ptr()),
|
||||
static_cast<StrideC*>(c_strides.data_ptr())};
|
||||
|
||||
int device_id = a_tensors.device().index();
|
||||
int device_id = a_tensors.get_device_index();
|
||||
static const cutlass::KernelHardwareInfo hw_info{
|
||||
device_id, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
|
||||
device_id)};
|
||||
@@ -170,9 +179,9 @@ void cutlass_group_gemm_caller(
|
||||
CUTLASS_CHECK(gemm_op.can_implement(args));
|
||||
|
||||
size_t workspace_size = gemm_op.get_workspace_size(args);
|
||||
auto const workspace_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(a_tensors.device());
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
auto workspace =
|
||||
torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
|
||||
std::nullopt, device);
|
||||
|
||||
cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
|
||||
CUTLASS_CHECK(status);
|
||||
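The allocation change above replaces torch::empty + TensorOptions with torch::stable::empty, which takes the shape, a headeronly ScalarType, an optional layout, and the target device. A minimal sketch using the same argument pattern as the hunk:

#include <cstdint>
#include <optional>
#include <torch/csrc/stable/ops.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/core/ScalarType.h>

// Sketch: allocate an int64 pointer table on the same device as an existing
// tensor, mirroring the a_ptrs / b_ptrs / out_ptrs allocations in this hunk.
inline torch::stable::Tensor make_ptr_table(torch::stable::Tensor const& like,
                                            int64_t num_experts) {
  return torch::stable::empty({num_experts},
                              torch::headeronly::ScalarType::Long,
                              std::nullopt, like.device());
}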
@@ -1,7 +1,8 @@
|
||||
#include <cudaTypedefs.h>
|
||||
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/all.h>
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include <torch/headeronly/core/ScalarType.h>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "grouped_mm_c3x.cuh"
|
||||
@@ -62,21 +63,27 @@ struct sm100_fp8_config_N8192 {
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType>
|
||||
void run_cutlass_moe_mm_sm100(
|
||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch) {
|
||||
TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
|
||||
TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
|
||||
TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");
|
||||
void run_cutlass_moe_mm_sm100(torch::stable::Tensor& out_tensors,
|
||||
torch::stable::Tensor const& a_tensors,
|
||||
torch::stable::Tensor const& b_tensors,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
torch::stable::Tensor const& expert_offsets,
|
||||
torch::stable::Tensor const& problem_sizes,
|
||||
torch::stable::Tensor const& a_strides,
|
||||
torch::stable::Tensor const& b_strides,
|
||||
torch::stable::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch) {
|
||||
STD_TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
|
||||
STD_TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
|
||||
STD_TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");
|
||||
|
||||
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn,
|
||||
"A tensors must be of type float8_e4m3fn.");
|
||||
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn,
|
||||
"B tensors must be of type float8_e4m3fn.");
|
||||
STD_TORCH_CHECK(
|
||||
a_tensors.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn,
|
||||
"A tensors must be of type float8_e4m3fn.");
|
||||
STD_TORCH_CHECK(
|
||||
b_tensors.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn,
|
||||
"B tensors must be of type float8_e4m3fn.");
|
||||
|
||||
using Cutlass3xGemmDefault = typename sm100_fp8_config_default<
|
||||
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
|
||||
@@ -107,14 +114,18 @@ void run_cutlass_moe_mm_sm100(
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void dispatch_moe_mm_sm100(
|
||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch) {
|
||||
if (out_tensors.dtype() == torch::kBFloat16) {
|
||||
void dispatch_moe_mm_sm100(torch::stable::Tensor& out_tensors,
|
||||
torch::stable::Tensor const& a_tensors,
|
||||
torch::stable::Tensor const& b_tensors,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
torch::stable::Tensor const& expert_offsets,
|
||||
torch::stable::Tensor const& problem_sizes,
|
||||
torch::stable::Tensor const& a_strides,
|
||||
torch::stable::Tensor const& b_strides,
|
||||
torch::stable::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch) {
|
||||
if (out_tensors.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
|
||||
run_cutlass_moe_mm_sm100<cutlass::float_e4m3_t, cutlass::bfloat16_t>(
|
||||
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
|
||||
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
|
||||
@@ -127,13 +138,17 @@ void dispatch_moe_mm_sm100(
|
||||
}
|
||||
}
|
||||
|
||||
void cutlass_moe_mm_sm100(
|
||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch) {
|
||||
void cutlass_moe_mm_sm100(torch::stable::Tensor& out_tensors,
|
||||
torch::stable::Tensor const& a_tensors,
|
||||
torch::stable::Tensor const& b_tensors,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
torch::stable::Tensor const& expert_offsets,
|
||||
torch::stable::Tensor const& problem_sizes,
|
||||
torch::stable::Tensor const& a_strides,
|
||||
torch::stable::Tensor const& b_strides,
|
||||
torch::stable::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch) {
|
||||
dispatch_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
|
||||
expert_offsets, problem_sizes, a_strides, b_strides,
|
||||
c_strides, per_act_token, per_out_ch);
|
||||
@@ -1,7 +1,8 @@
|
||||
#include <cudaTypedefs.h>
|
||||
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/all.h>
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include <torch/headeronly/core/ScalarType.h>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "grouped_mm_c3x.cuh"
|
||||
@@ -103,21 +104,27 @@ struct sm90_fp8_config_N8192 {
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType>
|
||||
void run_cutlass_moe_mm_sm90(
|
||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch) {
|
||||
TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
|
||||
TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
|
||||
TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");
|
||||
void run_cutlass_moe_mm_sm90(torch::stable::Tensor& out_tensors,
|
||||
torch::stable::Tensor const& a_tensors,
|
||||
torch::stable::Tensor const& b_tensors,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
torch::stable::Tensor const& expert_offsets,
|
||||
torch::stable::Tensor const& problem_sizes,
|
||||
torch::stable::Tensor const& a_strides,
|
||||
torch::stable::Tensor const& b_strides,
|
||||
torch::stable::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch) {
|
||||
STD_TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
|
||||
STD_TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
|
||||
STD_TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");
|
||||
|
||||
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn,
|
||||
"A tensors must be of type float8_e4m3fn.");
|
||||
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn,
|
||||
"B tensors must be of type float8_e4m3fn.");
|
||||
STD_TORCH_CHECK(
|
||||
a_tensors.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn,
|
||||
"A tensors must be of type float8_e4m3fn.");
|
||||
STD_TORCH_CHECK(
|
||||
b_tensors.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn,
|
||||
"B tensors must be of type float8_e4m3fn.");
|
||||
|
||||
using Cutlass3xGemmN8192 = typename sm90_fp8_config_N8192<
|
||||
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
|
||||
@@ -163,14 +170,18 @@ void run_cutlass_moe_mm_sm90(
|
||||
}
|
||||
}
|
||||
|
||||
void dispatch_moe_mm_sm90(
|
||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch) {
|
||||
if (out_tensors.dtype() == torch::kBFloat16) {
|
||||
void dispatch_moe_mm_sm90(torch::stable::Tensor& out_tensors,
|
||||
torch::stable::Tensor const& a_tensors,
|
||||
torch::stable::Tensor const& b_tensors,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
torch::stable::Tensor const& expert_offsets,
|
||||
torch::stable::Tensor const& problem_sizes,
|
||||
torch::stable::Tensor const& a_strides,
|
||||
torch::stable::Tensor const& b_strides,
|
||||
torch::stable::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch) {
|
||||
if (out_tensors.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
|
||||
run_cutlass_moe_mm_sm90<cutlass::float_e4m3_t, cutlass::bfloat16_t>(
|
||||
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
|
||||
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
|
||||
@@ -185,13 +196,17 @@ void dispatch_moe_mm_sm90(
|
||||
|
||||
} // namespace
|
||||
|
||||
void cutlass_moe_mm_sm90(
|
||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch) {
|
||||
void cutlass_moe_mm_sm90(torch::stable::Tensor& out_tensors,
|
||||
torch::stable::Tensor const& a_tensors,
|
||||
torch::stable::Tensor const& b_tensors,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
torch::stable::Tensor const& expert_offsets,
|
||||
torch::stable::Tensor const& problem_sizes,
|
||||
torch::stable::Tensor const& a_strides,
|
||||
torch::stable::Tensor const& b_strides,
|
||||
torch::stable::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch) {
|
||||
dispatch_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
|
||||
expert_offsets, problem_sizes, a_strides, b_strides,
|
||||
c_strides, per_act_token, per_out_ch);
|
||||
@@ -1,9 +1,11 @@
|
||||
#include <cudaTypedefs.h>
|
||||
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/all.h>
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include <torch/csrc/stable/ops.h>
|
||||
#include <torch/headeronly/core/ScalarType.h>
|
||||
|
||||
#include "dispatch_utils.h"
|
||||
#include "libtorch_stable/dispatch_utils.h"
|
||||
|
||||
#include <iostream>
|
||||
|
||||
@@ -110,19 +112,22 @@ __global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids,
|
||||
}
|
||||
|
||||
namespace {
|
||||
inline void launch_compute_problem_sizes(
|
||||
const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2, torch::Tensor& atomic_buffer,
|
||||
int64_t num_experts, int64_t n, int64_t k, cudaStream_t stream,
|
||||
const bool swap_ab, const bool is_gated) {
|
||||
inline void launch_compute_problem_sizes(const torch::stable::Tensor& topk_ids,
|
||||
torch::stable::Tensor& problem_sizes1,
|
||||
torch::stable::Tensor& problem_sizes2,
|
||||
torch::stable::Tensor& atomic_buffer,
|
||||
int64_t num_experts, int64_t n,
|
||||
int64_t k, cudaStream_t stream,
|
||||
const bool swap_ab,
|
||||
const bool is_gated) {
|
||||
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
|
||||
|
||||
auto const* topk_ptr = topk_ids.data_ptr<int32_t>();
|
||||
auto* ps1_ptr = problem_sizes1.data_ptr<int32_t>();
|
||||
auto* ps2_ptr = problem_sizes2.data_ptr<int32_t>();
|
||||
auto* atomic_ptr = atomic_buffer.data_ptr<int32_t>();
|
||||
auto const* topk_ptr = topk_ids.const_data_ptr<int32_t>();
|
||||
auto* ps1_ptr = problem_sizes1.mutable_data_ptr<int32_t>();
|
||||
auto* ps2_ptr = problem_sizes2.mutable_data_ptr<int32_t>();
|
||||
auto* atomic_ptr = atomic_buffer.mutable_data_ptr<int32_t>();
|
||||
|
||||
VLLM_DISPATCH_BOOL(swap_ab, SwapAB, [&] {
|
||||
VLLM_STABLE_DISPATCH_BOOL(swap_ab, SwapAB, [&] {
|
||||
compute_problem_sizes<SwapAB><<<num_experts, num_threads, 0, stream>>>(
|
||||
topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr,
|
||||
static_cast<int>(topk_ids.numel()), static_cast<int>(n),
|
||||
@@ -171,46 +176,53 @@ __global__ void compute_problem_sizes_from_expert_offsets(
|
||||
}
|
||||
|
||||
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
|
||||
const torch::Tensor& expert_first_token_offset,
|
||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||
const int64_t n, const int64_t k, const bool swap_ab) {
|
||||
TORCH_CHECK(expert_first_token_offset.is_cuda(),
|
||||
"expert_first_token_offset must be a CUDA tensor");
|
||||
TORCH_CHECK(expert_first_token_offset.dtype() == torch::kInt64,
|
||||
"expert_first_token_offset must be int64");
|
||||
const torch::stable::Tensor& expert_first_token_offset,
|
||||
torch::stable::Tensor& problem_sizes1,
|
||||
torch::stable::Tensor& problem_sizes2, const int64_t n, const int64_t k,
|
||||
const bool swap_ab) {
|
||||
STD_TORCH_CHECK(expert_first_token_offset.is_cuda(),
|
||||
"expert_first_token_offset must be a CUDA tensor");
|
||||
STD_TORCH_CHECK(expert_first_token_offset.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Long,
|
||||
"expert_first_token_offset must be int64");
|
||||
|
||||
TORCH_CHECK(problem_sizes1.is_cuda() && problem_sizes2.is_cuda(),
|
||||
"problem_sizes must be CUDA tensors");
|
||||
TORCH_CHECK(problem_sizes1.dtype() == torch::kInt32 &&
|
||||
problem_sizes2.dtype() == torch::kInt32,
|
||||
"problem_sizes must be int32");
|
||||
TORCH_CHECK(problem_sizes1.is_contiguous() && problem_sizes2.is_contiguous(),
|
||||
"problem_sizes must be contiguous");
|
||||
TORCH_CHECK(problem_sizes1.dim() == 2 && problem_sizes2.dim() == 2,
|
||||
"problem_sizes must be 2D tensors");
|
||||
TORCH_CHECK(problem_sizes1.size(1) == 3 && problem_sizes2.size(1) == 3,
|
||||
"problem_sizes second dim must be 3");
|
||||
TORCH_CHECK(problem_sizes1.sizes() == problem_sizes2.sizes(),
|
||||
"problem_sizes1 and problem_sizes2 must have same shape");
|
||||
STD_TORCH_CHECK(problem_sizes1.is_cuda() && problem_sizes2.is_cuda(),
|
||||
"problem_sizes must be CUDA tensors");
|
||||
STD_TORCH_CHECK(
|
||||
problem_sizes1.scalar_type() == torch::headeronly::ScalarType::Int &&
|
||||
problem_sizes2.scalar_type() == torch::headeronly::ScalarType::Int,
|
||||
"problem_sizes must be int32");
|
||||
STD_TORCH_CHECK(
|
||||
problem_sizes1.is_contiguous() && problem_sizes2.is_contiguous(),
|
||||
"problem_sizes must be contiguous");
|
||||
STD_TORCH_CHECK(problem_sizes1.dim() == 2 && problem_sizes2.dim() == 2,
|
||||
"problem_sizes must be 2D tensors");
|
||||
STD_TORCH_CHECK(problem_sizes1.size(1) == 3 && problem_sizes2.size(1) == 3,
|
||||
"problem_sizes second dim must be 3");
|
||||
STD_TORCH_CHECK(problem_sizes1.size(0) == problem_sizes2.size(0) &&
|
||||
problem_sizes1.size(1) == problem_sizes2.size(1),
|
||||
"problem_sizes1 and problem_sizes2 must have same shape");
|
||||
|
||||
int64_t const num_experts64 = problem_sizes1.size(0);
|
||||
TORCH_CHECK(expert_first_token_offset.numel() == num_experts64 + 1,
|
||||
"expert_first_token_offset must have num_experts + 1 elements");
|
||||
TORCH_CHECK(num_experts64 <= INT32_MAX, "num_experts must fit in int32");
|
||||
TORCH_CHECK(n <= INT32_MAX && k <= INT32_MAX, "n and k must fit in int32");
|
||||
STD_TORCH_CHECK(
|
||||
expert_first_token_offset.numel() == num_experts64 + 1,
|
||||
"expert_first_token_offset must have num_experts + 1 elements");
|
||||
STD_TORCH_CHECK(num_experts64 <= INT32_MAX, "num_experts must fit in int32");
|
||||
STD_TORCH_CHECK(n <= INT32_MAX && k <= INT32_MAX,
|
||||
"n and k must fit in int32");
|
||||
|
||||
int const num_experts = static_cast<int>(num_experts64);
|
||||
auto stream = at::cuda::getCurrentCUDAStream(
|
||||
expert_first_token_offset.device().index());
|
||||
auto stream =
|
||||
get_current_cuda_stream(expert_first_token_offset.get_device_index());
|
||||
|
||||
int const threads = (num_experts < 256) ? num_experts : 256;
|
||||
int const blocks = (num_experts + threads - 1) / threads;
|
||||
|
||||
auto const* offsets_ptr = expert_first_token_offset.data_ptr<int64_t>();
|
||||
auto* ps1_ptr = problem_sizes1.data_ptr<int32_t>();
|
||||
auto* ps2_ptr = problem_sizes2.data_ptr<int32_t>();
|
||||
auto const* offsets_ptr = expert_first_token_offset.const_data_ptr<int64_t>();
|
||||
auto* ps1_ptr = problem_sizes1.mutable_data_ptr<int32_t>();
|
||||
auto* ps2_ptr = problem_sizes2.mutable_data_ptr<int32_t>();
|
||||
|
||||
VLLM_DISPATCH_BOOL(swap_ab, SwapAB, [&] {
|
||||
VLLM_STABLE_DISPATCH_BOOL(swap_ab, SwapAB, [&] {
|
||||
compute_problem_sizes_from_expert_offsets<SwapAB>
|
||||
<<<blocks, threads, 0, stream>>>(offsets_ptr, ps1_ptr, ps2_ptr,
|
||||
num_experts, static_cast<int>(n),
|
||||
@@ -219,16 +231,19 @@ void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
|
||||
}
|
||||
|
||||
void get_cutlass_moe_mm_data_caller(
|
||||
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
|
||||
const int64_t num_experts, const int64_t n, const int64_t k,
|
||||
const std::optional<torch::Tensor>& blockscale_offsets,
|
||||
const torch::stable::Tensor& topk_ids,
|
||||
torch::stable::Tensor& expert_offsets,
|
||||
torch::stable::Tensor& problem_sizes1,
|
||||
torch::stable::Tensor& problem_sizes2,
|
||||
torch::stable::Tensor& input_permutation,
|
||||
torch::stable::Tensor& output_permutation, const int64_t num_experts,
|
||||
const int64_t n, const int64_t k,
|
||||
const std::optional<torch::stable::Tensor>& blockscale_offsets,
|
||||
const bool is_gated) {
|
||||
auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index());
|
||||
auto options_int32 =
|
||||
torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
|
||||
torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32);
|
||||
auto device = topk_ids.device();
|
||||
auto stream = get_current_cuda_stream(device.index());
|
||||
torch::stable::Tensor atomic_buffer = torch::stable::new_zeros(
|
||||
topk_ids, {num_experts}, torch::headeronly::ScalarType::Int);
|
||||
|
||||
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
|
||||
|
||||
@@ -290,11 +305,13 @@ __global__ void compute_batched_moe_data(
|
||||
}
|
||||
|
||||
void get_cutlass_batched_moe_mm_data_caller(
|
||||
torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2, const torch::Tensor& expert_num_tokens,
|
||||
torch::stable::Tensor& expert_offsets,
|
||||
torch::stable::Tensor& problem_sizes1,
|
||||
torch::stable::Tensor& problem_sizes2,
|
||||
const torch::stable::Tensor& expert_num_tokens,
|
||||
const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
|
||||
const int64_t k) {
|
||||
auto stream = at::cuda::getCurrentCUDAStream(expert_offsets.device().index());
|
||||
auto stream = get_current_cuda_stream(expert_offsets.get_device_index());
|
||||
|
||||
if (num_local_experts * padded_m > SWAP_AB_THRESHOLD) {
|
||||
compute_batched_moe_data<false><<<1, num_local_experts, 0, stream>>>(
|
||||
@@ -311,4 +328,4 @@ void get_cutlass_batched_moe_mm_data_caller(
|
||||
static_cast<const int32_t*>(expert_num_tokens.data_ptr()), padded_m, n,
|
||||
k);
|
||||
}
|
||||
}
|
||||
}
|
||||
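A final recurring change in this file: raw data_ptr<T>() accesses are split into const_data_ptr<T>() for inputs and mutable_data_ptr<T>() for outputs on the stable Tensor. A minimal sketch, using only the accessors that appear in the hunk (the pointers are handed to CUDA kernels, not dereferenced on the host):

#include <cstdint>
#include <torch/csrc/stable/tensor.h>

// Sketch: obtain kernel arguments the way launch_compute_problem_sizes now
// does: a const pointer for the read-only topk ids, mutable pointers for the
// problem-size outputs.
inline void gather_kernel_args(torch::stable::Tensor const& topk_ids,
                               torch::stable::Tensor& problem_sizes1,
                               torch::stable::Tensor& problem_sizes2) {
  auto const* topk_ptr = topk_ids.const_data_ptr<int32_t>();
  auto* ps1_ptr = problem_sizes1.mutable_data_ptr<int32_t>();
  auto* ps2_ptr = problem_sizes2.mutable_data_ptr<int32_t>();
  (void)topk_ptr;
  (void)ps1_ptr;
  (void)ps2_ptr;  // in the real code these feed the <<<...>>> kernel launches
}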
220
csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c2x.cu
Normal file
@@ -0,0 +1,220 @@
|
||||
#include <stddef.h>
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include <torch/headeronly/core/ScalarType.h>
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "scaled_mm_c2x.cuh"
|
||||
#include "scaled_mm_c2x_sm75_dispatch.cuh"
|
||||
#include "scaled_mm_c2x_sm80_dispatch.cuh"
|
||||
#include "scaled_mm_c2x_sm89_fp8_dispatch.cuh"
|
||||
#include "scaled_mm_c2x_sm89_int8_dispatch.cuh"
|
||||
|
||||
#include "libtorch_stable/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp"
|
||||
|
||||
using namespace vllm;
|
||||
|
||||
/*
|
||||
This file defines quantized GEMM operations using the CUTLASS 2.x API, for
|
||||
NVIDIA GPUs with SM versions prior to sm90 (Hopper).
|
||||
*/
|
||||
|
||||
template <template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_scaled_mm_sm75_epilogue(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_args) {
|
||||
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||
|
||||
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
|
||||
return cutlass_gemm_sm75_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
|
||||
return cutlass_gemm_sm75_dispatch<int8_t, cutlass::half_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm_sm75(torch::stable::Tensor& out,
                            torch::stable::Tensor const& a,
                            torch::stable::Tensor const& b,
                            torch::stable::Tensor const& a_scales,
                            torch::stable::Tensor const& b_scales,
                            std::optional<torch::stable::Tensor> const& bias) {
  STD_TORCH_CHECK(a_scales.scalar_type() ==
                  torch::headeronly::ScalarType::Float);
  STD_TORCH_CHECK(b_scales.scalar_type() ==
                  torch::headeronly::ScalarType::Float);
  if (bias) {
    STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
                    "currently bias dtype must match output dtype ",
                    out.scalar_type());
    return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBias>(
        out, a, b, a_scales, b_scales, *bias);
  } else {
    return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogue>(
        out, a, b, a_scales, b_scales);
  }
}

void cutlass_scaled_mm_azp_sm75(
    torch::stable::Tensor& out, torch::stable::Tensor const& a,
    torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
    torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
    std::optional<torch::stable::Tensor> const& azp,
    std::optional<torch::stable::Tensor> const& bias) {
  STD_TORCH_CHECK(a_scales.scalar_type() ==
                  torch::headeronly::ScalarType::Float);
  STD_TORCH_CHECK(b_scales.scalar_type() ==
                  torch::headeronly::ScalarType::Float);

  if (azp) {
    return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
        out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
  } else {
    return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
  }
}

template <template <typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_scaled_mm_sm80_epilogue(torch::stable::Tensor& out,
                                     torch::stable::Tensor const& a,
                                     torch::stable::Tensor const& b,
                                     EpilogueArgs&&... epilogue_args) {
  STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
  STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);

  if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
    return cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
  } else {
    STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
    return cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
  }
}

void cutlass_scaled_mm_sm80(torch::stable::Tensor& out,
                            torch::stable::Tensor const& a,
                            torch::stable::Tensor const& b,
                            torch::stable::Tensor const& a_scales,
                            torch::stable::Tensor const& b_scales,
                            std::optional<torch::stable::Tensor> const& bias) {
  STD_TORCH_CHECK(a_scales.scalar_type() ==
                  torch::headeronly::ScalarType::Float);
  STD_TORCH_CHECK(b_scales.scalar_type() ==
                  torch::headeronly::ScalarType::Float);
  if (bias) {
    STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
                    "currently bias dtype must match output dtype ",
                    out.scalar_type());
    return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBias>(
        out, a, b, a_scales, b_scales, *bias);
  } else {
    return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogue>(
        out, a, b, a_scales, b_scales);
  }
}

void cutlass_scaled_mm_azp_sm80(
    torch::stable::Tensor& out, torch::stable::Tensor const& a,
    torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
    torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
    std::optional<torch::stable::Tensor> const& azp,
    std::optional<torch::stable::Tensor> const& bias) {
  STD_TORCH_CHECK(a_scales.scalar_type() ==
                  torch::headeronly::ScalarType::Float);
  STD_TORCH_CHECK(b_scales.scalar_type() ==
                  torch::headeronly::ScalarType::Float);

  if (azp) {
    return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
        out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
  } else {
    return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
  }
}

template <template <typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_scaled_mm_sm89_epilogue(torch::stable::Tensor& out,
                                     torch::stable::Tensor const& a,
                                     torch::stable::Tensor const& b,
                                     EpilogueArgs&&... epilogue_args) {
  if (a.scalar_type() == torch::headeronly::ScalarType::Char) {
    STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);

    if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
      return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::bfloat16_t,
                                             Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    } else {
      assert(out.scalar_type() == torch::headeronly::ScalarType::Half);
      return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    }
  } else {
    STD_TORCH_CHECK(a.scalar_type() ==
                    torch::headeronly::ScalarType::Float8_e4m3fn);
    STD_TORCH_CHECK(b.scalar_type() ==
                    torch::headeronly::ScalarType::Float8_e4m3fn);

    if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
      return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
                                            cutlass::bfloat16_t, Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    } else {
      STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
      return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
                                            cutlass::half_t, Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    }
  }
}

void cutlass_scaled_mm_sm89(torch::stable::Tensor& out,
                            torch::stable::Tensor const& a,
                            torch::stable::Tensor const& b,
                            torch::stable::Tensor const& a_scales,
                            torch::stable::Tensor const& b_scales,
                            std::optional<torch::stable::Tensor> const& bias) {
  STD_TORCH_CHECK(a_scales.scalar_type() ==
                  torch::headeronly::ScalarType::Float);
  STD_TORCH_CHECK(b_scales.scalar_type() ==
                  torch::headeronly::ScalarType::Float);
  if (bias) {
    STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
                    "currently bias dtype must match output dtype ",
                    out.scalar_type());
    return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBias>(
        out, a, b, a_scales, b_scales, *bias);
  } else {
    return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogue>(
        out, a, b, a_scales, b_scales);
  }
}

void cutlass_scaled_mm_azp_sm89(
    torch::stable::Tensor& out, torch::stable::Tensor const& a,
    torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
    torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
    std::optional<torch::stable::Tensor> const& azp,
    std::optional<torch::stable::Tensor> const& bias) {
  STD_TORCH_CHECK(a_scales.scalar_type() ==
                  torch::headeronly::ScalarType::Float);
  STD_TORCH_CHECK(b_scales.scalar_type() ==
                  torch::headeronly::ScalarType::Float);

  if (azp) {
    return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
        out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
  } else {
    return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
  }
}
@@ -1,8 +1,9 @@
#pragma once

#include <stddef.h>
#include <torch/all.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/ops.h>

#include <ATen/cuda/CUDAContext.h>
#include "libtorch_stable/torch_utils.h"

// clang-format will break include orders
// clang-format off
@@ -95,8 +96,9 @@ struct cutlass_2x_gemm {
};

template <typename Gemm, typename... EpilogueArgs>
inline void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& b,
inline void cutlass_gemm_caller(torch::stable::Tensor& out,
                                torch::stable::Tensor const& a,
                                torch::stable::Tensor const& b,
                                EpilogueArgs&&... epilogue_params) {
  using ElementAB = typename Gemm::ElementAB;
  using ElementD = typename Gemm::ElementD;
@@ -149,11 +151,12 @@ inline void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
  // Launch the CUTLASS GEMM kernel.
  typename Gemm::Op gemm_op;
  size_t workspace_size = gemm_op.get_workspace_size(args);
  auto const workspace_options =
      torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
  auto workspace = torch::empty(workspace_size, workspace_options);
  auto device = a.device();
  auto workspace =
      torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
                           std::nullopt, device);

  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
  auto stream = get_current_cuda_stream(device.index());

  CUTLASS_CHECK(gemm_op.can_implement(args));
  cutlass::Status status = gemm_op(args, workspace.data_ptr(), stream);
@@ -161,9 +164,9 @@ inline void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
}

template <typename Gemm, typename FallbackGemm, typename... EpilogueArgs>
inline void fallback_cutlass_gemm_caller(torch::Tensor& out,
                                         torch::Tensor const& a,
                                         torch::Tensor const& b,
inline void fallback_cutlass_gemm_caller(torch::stable::Tensor& out,
                                         torch::stable::Tensor const& a,
                                         torch::stable::Tensor const& b,
                                         EpilogueArgs&&... args) {
  // In some cases, the GPU isn't able to accommodate the
  // shared memory requirements of the Gemm. In such cases, use
@@ -180,8 +183,8 @@ inline void fallback_cutlass_gemm_caller(torch::Tensor& out,
    return cutlass_gemm_caller<Gemm>(out, a, b,
                                     std::forward<EpilogueArgs>(args)...);
  } else {
    TORCH_CHECK(fallback_gemm_shared_mem_size <=
                max_shared_mem_per_block_opt_in);
    STD_TORCH_CHECK(fallback_gemm_shared_mem_size <=
                    max_shared_mem_per_block_opt_in);
    return cutlass_gemm_caller<FallbackGemm>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  }
@@ -1,5 +1,7 @@
#pragma once

#include <torch/headeronly/util/shim_utils.h>

#include "scaled_mm_c2x.cuh"

/**
@@ -70,13 +72,13 @@ struct sm75_config_M32 {
template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue,
          typename... EpilogueArgs>
inline void cutlass_gemm_sm75_dispatch(torch::Tensor& out,
                                       torch::Tensor const& a,
                                       torch::Tensor const& b,
inline void cutlass_gemm_sm75_dispatch(torch::stable::Tensor& out,
                                       torch::stable::Tensor const& a,
                                       torch::stable::Tensor const& b,
                                       EpilogueArgs&&... args) {
  static_assert(std::is_same<InType, int8_t>());
  TORCH_CHECK(a.dtype() == torch::kInt8);
  TORCH_CHECK(b.dtype() == torch::kInt8);
  STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
  STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);

  using Cutlass2xGemmDefault =
      typename sm75_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
@@ -1,5 +1,7 @@
#pragma once

#include <torch/headeronly/util/shim_utils.h>

#include "scaled_mm_c2x.cuh"

/**
@@ -72,13 +74,13 @@ struct sm80_config_M16 {
template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue,
          typename... EpilogueArgs>
inline void cutlass_gemm_sm80_dispatch(torch::Tensor& out,
                                       torch::Tensor const& a,
                                       torch::Tensor const& b,
inline void cutlass_gemm_sm80_dispatch(torch::stable::Tensor& out,
                                       torch::stable::Tensor const& a,
                                       torch::stable::Tensor const& b,
                                       EpilogueArgs&&... args) {
  static_assert(std::is_same<InType, int8_t>());
  TORCH_CHECK(a.dtype() == torch::kInt8);
  TORCH_CHECK(b.dtype() == torch::kInt8);
  STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
  STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);

  using Cutlass2xGemmDefault =
      typename sm80_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
@@ -1,5 +1,7 @@
#pragma once

#include <torch/headeronly/util/shim_utils.h>

#include "scaled_mm_c2x.cuh"
#include "cutlass/float8.h"

@@ -34,10 +36,12 @@ struct sm89_fp8_config_default {
  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b, EpilogueArgs&&... args) {
  static void dispatch(torch::stable::Tensor& out,
                       torch::stable::Tensor const& a,
                       torch::stable::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
    TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
    STD_TORCH_CHECK(a.scalar_type() ==
                    torch::headeronly::ScalarType::Float8_e4m3fn);

    using FallbackGemm =
        typename sm89_fp8_fallback_gemm<InType, OutType,
@@ -84,10 +88,12 @@ struct sm89_fp8_config_M256 {
  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b, EpilogueArgs&&... args) {
  static void dispatch(torch::stable::Tensor& out,
                       torch::stable::Tensor const& a,
                       torch::stable::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
    TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
    STD_TORCH_CHECK(a.scalar_type() ==
                    torch::headeronly::ScalarType::Float8_e4m3fn);

    using FallbackGemm =
        typename sm89_fp8_fallback_gemm<InType, OutType,
@@ -125,10 +131,12 @@ struct sm89_fp8_config_M128 {
  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b, EpilogueArgs&&... args) {
  static void dispatch(torch::stable::Tensor& out,
                       torch::stable::Tensor const& a,
                       torch::stable::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
    TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
    STD_TORCH_CHECK(a.scalar_type() ==
                    torch::headeronly::ScalarType::Float8_e4m3fn);

    using FallbackGemm =
        typename sm89_fp8_fallback_gemm<InType, OutType,
@@ -173,10 +181,12 @@ struct sm89_fp8_config_M64 {
  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b, EpilogueArgs&&... args) {
  static void dispatch(torch::stable::Tensor& out,
                       torch::stable::Tensor const& a,
                       torch::stable::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
    TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
    STD_TORCH_CHECK(a.scalar_type() ==
                    torch::headeronly::ScalarType::Float8_e4m3fn);

    using FallbackGemm =
        typename sm89_fp8_fallback_gemm<InType, OutType,
@@ -227,10 +237,12 @@ struct sm89_fp8_config_M32 {
  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b, EpilogueArgs&&... args) {
  static void dispatch(torch::stable::Tensor& out,
                       torch::stable::Tensor const& a,
                       torch::stable::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
    TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
    STD_TORCH_CHECK(a.scalar_type() ==
                    torch::headeronly::ScalarType::Float8_e4m3fn);

    using FallbackGemm =
        typename sm89_fp8_fallback_gemm<InType, OutType,
@@ -280,10 +292,12 @@ struct sm89_fp8_config_M16 {
  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b, EpilogueArgs&&... args) {
  static void dispatch(torch::stable::Tensor& out,
                       torch::stable::Tensor const& a,
                       torch::stable::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
    TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
    STD_TORCH_CHECK(a.scalar_type() ==
                    torch::headeronly::ScalarType::Float8_e4m3fn);

    using FallbackGemm =
        typename sm89_fp8_fallback_gemm<InType, OutType,
@@ -326,13 +340,15 @@ struct sm89_fp8_config_M16 {
template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue,
          typename... EpilogueArgs>
inline void cutlass_gemm_sm89_fp8_dispatch(torch::Tensor& out,
                                           torch::Tensor const& a,
                                           torch::Tensor const& b,
inline void cutlass_gemm_sm89_fp8_dispatch(torch::stable::Tensor& out,
                                           torch::stable::Tensor const& a,
                                           torch::stable::Tensor const& b,
                                           EpilogueArgs&&... args) {
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
  STD_TORCH_CHECK(a.scalar_type() ==
                  torch::headeronly::ScalarType::Float8_e4m3fn);
  STD_TORCH_CHECK(b.scalar_type() ==
                  torch::headeronly::ScalarType::Float8_e4m3fn);

  uint32_t const m = a.size(0);
  uint32_t const mp2 =
@@ -1,5 +1,7 @@
#pragma once

#include <torch/headeronly/util/shim_utils.h>

#include "scaled_mm_c2x.cuh"

/**
@@ -32,10 +34,11 @@ struct sm89_int8_config_default {
  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b, EpilogueArgs&&... args) {
  static void dispatch(torch::stable::Tensor& out,
                       torch::stable::Tensor const& a,
                       torch::stable::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, int8_t>());
    TORCH_CHECK(a.dtype() == torch::kInt8);
    STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);

    using FallbackGemm =
        typename sm89_int8_fallback_gemm<InType, OutType,
@@ -88,10 +91,11 @@ struct sm89_int8_config_M256 {
  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b, EpilogueArgs&&... args) {
  static void dispatch(torch::stable::Tensor& out,
                       torch::stable::Tensor const& a,
                       torch::stable::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, int8_t>());
    TORCH_CHECK(a.dtype() == torch::kInt8);
    STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);

    using FallbackGemm =
        typename sm89_int8_fallback_gemm<InType, OutType,
@@ -143,10 +147,11 @@ struct sm89_int8_config_M128 {
  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b, EpilogueArgs&&... args) {
  static void dispatch(torch::stable::Tensor& out,
                       torch::stable::Tensor const& a,
                       torch::stable::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, int8_t>());
    TORCH_CHECK(a.dtype() == torch::kInt8);
    STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);

    using FallbackGemm =
        typename sm89_int8_fallback_gemm<InType, OutType,
@@ -193,10 +198,11 @@ struct sm89_int8_config_M64 {
  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b, EpilogueArgs&&... args) {
  static void dispatch(torch::stable::Tensor& out,
                       torch::stable::Tensor const& a,
                       torch::stable::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, int8_t>());
    TORCH_CHECK(a.dtype() == torch::kInt8);
    STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);

    using FallbackGemm =
        typename sm89_int8_fallback_gemm<InType, OutType,
@@ -234,10 +240,11 @@ struct sm89_int8_config_M32 {
  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b, EpilogueArgs&&... args) {
  static void dispatch(torch::stable::Tensor& out,
                       torch::stable::Tensor const& a,
                       torch::stable::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, int8_t>());
    TORCH_CHECK(a.dtype() == torch::kInt8);
    STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);

    using FallbackGemm =
        typename sm89_int8_fallback_gemm<InType, OutType,
@@ -276,10 +283,11 @@ struct sm89_int8_config_M16 {
  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b, EpilogueArgs&&... args) {
  static void dispatch(torch::stable::Tensor& out,
                       torch::stable::Tensor const& a,
                       torch::stable::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, int8_t>());
    TORCH_CHECK(a.dtype() == torch::kInt8);
    STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);

    using FallbackGemm =
        typename sm89_int8_fallback_gemm<InType, OutType,
@@ -311,13 +319,13 @@ struct sm89_int8_config_M16 {
template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue,
          typename... EpilogueArgs>
inline void cutlass_gemm_sm89_int8_dispatch(torch::Tensor& out,
                                            torch::Tensor const& a,
                                            torch::Tensor const& b,
inline void cutlass_gemm_sm89_int8_dispatch(torch::stable::Tensor& out,
                                            torch::stable::Tensor const& a,
                                            torch::stable::Tensor const& b,
                                            EpilogueArgs&&... args) {
  static_assert(std::is_same<InType, int8_t>());
  TORCH_CHECK(a.dtype() == torch::kInt8);
  TORCH_CHECK(b.dtype() == torch::kInt8);
  STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
  STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);

  uint32_t const m = a.size(0);
  uint32_t const mp2 =
@@ -8,11 +8,12 @@

#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100

void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
                             torch::Tensor const& b,
                             torch::Tensor const& a_scales,
                             torch::Tensor const& b_scales,
                             std::optional<torch::Tensor> const& bias) {
void cutlass_scaled_mm_sm100(torch::stable::Tensor& c,
                             torch::stable::Tensor const& a,
                             torch::stable::Tensor const& b,
                             torch::stable::Tensor const& a_scales,
                             torch::stable::Tensor const& b_scales,
                             std::optional<torch::stable::Tensor> const& bias) {
  dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
                     vllm::cutlass_scaled_mm_sm100_fp8,
                     nullptr,  // int8 not supported on SM100
@@ -8,11 +8,12 @@

#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120

void cutlass_scaled_mm_sm120(torch::Tensor& c, torch::Tensor const& a,
                             torch::Tensor const& b,
                             torch::Tensor const& a_scales,
                             torch::Tensor const& b_scales,
                             std::optional<torch::Tensor> const& bias) {
void cutlass_scaled_mm_sm120(torch::stable::Tensor& c,
                             torch::stable::Tensor const& a,
                             torch::stable::Tensor const& b,
                             torch::stable::Tensor const& a_scales,
                             torch::stable::Tensor const& b_scales,
                             std::optional<torch::stable::Tensor> const& bias) {
  dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
                     vllm::cutlass_scaled_mm_sm120_fp8,
                     nullptr,  // int8 not supported on SM120
@@ -0,0 +1,38 @@
#include "c3x/scaled_mm_helper.hpp"
#include "c3x/scaled_mm_kernels.hpp"

/*
   This file defines quantized GEMM operations using the CUTLASS 3.x API, for
   NVIDIA GPUs with sm90a (Hopper).
*/

#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90

void cutlass_scaled_mm_sm90(torch::stable::Tensor& c,
                            torch::stable::Tensor const& a,
                            torch::stable::Tensor const& b,
                            torch::stable::Tensor const& a_scales,
                            torch::stable::Tensor const& b_scales,
                            std::optional<torch::stable::Tensor> const& bias) {
  dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
                     vllm::cutlass_scaled_mm_sm90_fp8,
                     vllm::cutlass_scaled_mm_sm90_int8,
                     vllm::cutlass_scaled_mm_blockwise_sm90_fp8);
}

void cutlass_scaled_mm_azp_sm90(
    torch::stable::Tensor& out, torch::stable::Tensor const& a,
    torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
    torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
    std::optional<torch::stable::Tensor> const& azp,
    std::optional<torch::stable::Tensor> const& bias) {
  STD_TORCH_CHECK(a_scales.scalar_type() ==
                  torch::headeronly::ScalarType::Float);
  STD_TORCH_CHECK(b_scales.scalar_type() ==
                  torch::headeronly::ScalarType::Float);

  vllm::cutlass_scaled_mm_azp_sm90_int8(out, a, b, a_scales, b_scales, azp_adj,
                                        azp, bias);
}

#endif
@@ -0,0 +1,451 @@
#include <cudaTypedefs.h>

#include <torch/csrc/stable/tensor.h>

#include "libtorch_stable/torch_utils.h"

#include "cutlass_extensions/common.hpp"

void cutlass_scaled_mm_sm75(torch::stable::Tensor& c,
                            torch::stable::Tensor const& a,
                            torch::stable::Tensor const& b,
                            torch::stable::Tensor const& a_scales,
                            torch::stable::Tensor const& b_scales,
                            std::optional<torch::stable::Tensor> const& bias);

void cutlass_scaled_mm_sm80(torch::stable::Tensor& c,
                            torch::stable::Tensor const& a,
                            torch::stable::Tensor const& b,
                            torch::stable::Tensor const& a_scales,
                            torch::stable::Tensor const& b_scales,
                            std::optional<torch::stable::Tensor> const& bias);

void cutlass_scaled_mm_sm89(torch::stable::Tensor& c,
                            torch::stable::Tensor const& a,
                            torch::stable::Tensor const& b,
                            torch::stable::Tensor const& a_scales,
                            torch::stable::Tensor const& b_scales,
                            std::optional<torch::stable::Tensor> const& bias);

#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
void cutlass_scaled_mm_sm90(torch::stable::Tensor& c,
                            torch::stable::Tensor const& a,
                            torch::stable::Tensor const& b,
                            torch::stable::Tensor const& a_scales,
                            torch::stable::Tensor const& b_scales,
                            std::optional<torch::stable::Tensor> const& bias);
#endif
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
void cutlass_moe_mm_sm90(torch::stable::Tensor& out_tensors,
                         torch::stable::Tensor const& a_tensors,
                         torch::stable::Tensor const& b_tensors,
                         torch::stable::Tensor const& a_scales,
                         torch::stable::Tensor const& b_scales,
                         torch::stable::Tensor const& expert_offsets,
                         torch::stable::Tensor const& problem_sizes,
                         torch::stable::Tensor const& a_strides,
                         torch::stable::Tensor const& b_strides,
                         torch::stable::Tensor const& c_strides,
                         bool per_act_token, bool per_out_ch);

#endif

#if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100
void cutlass_moe_mm_sm100(torch::stable::Tensor& out_tensors,
                          torch::stable::Tensor const& a_tensors,
                          torch::stable::Tensor const& b_tensors,
                          torch::stable::Tensor const& a_scales,
                          torch::stable::Tensor const& b_scales,
                          torch::stable::Tensor const& expert_offsets,
                          torch::stable::Tensor const& problem_sizes,
                          torch::stable::Tensor const& a_strides,
                          torch::stable::Tensor const& b_strides,
                          torch::stable::Tensor const& c_strides,
                          bool per_act_token, bool per_out_ch);
#endif

#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
void cutlass_scaled_mm_sm120(torch::stable::Tensor& c,
                             torch::stable::Tensor const& a,
                             torch::stable::Tensor const& b,
                             torch::stable::Tensor const& a_scales,
                             torch::stable::Tensor const& b_scales,
                             std::optional<torch::stable::Tensor> const& bias);
#endif

#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
void cutlass_scaled_mm_sm100(torch::stable::Tensor& c,
                             torch::stable::Tensor const& a,
                             torch::stable::Tensor const& b,
                             torch::stable::Tensor const& a_scales,
                             torch::stable::Tensor const& b_scales,
                             std::optional<torch::stable::Tensor> const& bias);
#endif

#if (defined(ENABLE_CUTLASS_MOE_SM90) && ENABLE_CUTLASS_MOE_SM90) ||   \
    (defined(ENABLE_CUTLASS_MOE_SM100) && ENABLE_CUTLASS_MOE_SM100) || \
    (defined(ENABLE_CUTLASS_MOE_SM120) && ENABLE_CUTLASS_MOE_SM120)
void get_cutlass_moe_mm_data_caller(
    const torch::stable::Tensor& topk_ids,
    torch::stable::Tensor& expert_offsets,
    torch::stable::Tensor& problem_sizes1,
    torch::stable::Tensor& problem_sizes2,
    torch::stable::Tensor& input_permutation,
    torch::stable::Tensor& output_permutation, const int64_t num_experts,
    const int64_t n, const int64_t k,
    const std::optional<torch::stable::Tensor>& blockscale_offsets,
    const bool is_gated);

void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
    const torch::stable::Tensor& expert_first_token_offset,
    torch::stable::Tensor& problem_sizes1,
    torch::stable::Tensor& problem_sizes2, const int64_t n, const int64_t k,
    const bool swap_ab);

void get_cutlass_batched_moe_mm_data_caller(
    torch::stable::Tensor& expert_offsets,
    torch::stable::Tensor& problem_sizes1,
    torch::stable::Tensor& problem_sizes2,
    const torch::stable::Tensor& expert_num_tokens,
    const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
    const int64_t k);
#endif

void cutlass_scaled_mm_azp_sm75(
    torch::stable::Tensor& c, torch::stable::Tensor const& a,
    torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
    torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
    std::optional<torch::stable::Tensor> const& azp,
    std::optional<torch::stable::Tensor> const& bias);

void cutlass_scaled_mm_azp_sm80(
    torch::stable::Tensor& c, torch::stable::Tensor const& a,
    torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
    torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
    std::optional<torch::stable::Tensor> const& azp,
    std::optional<torch::stable::Tensor> const& bias);

void cutlass_scaled_mm_azp_sm89(
    torch::stable::Tensor& c, torch::stable::Tensor const& a,
    torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
    torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
    std::optional<torch::stable::Tensor> const& azp,
    std::optional<torch::stable::Tensor> const& bias);

#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
void cutlass_scaled_mm_azp_sm90(
    torch::stable::Tensor& c, torch::stable::Tensor const& a,
    torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
    torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
    std::optional<torch::stable::Tensor> const& azp,
    std::optional<torch::stable::Tensor> const& bias);
#endif

bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
  // CUTLASS FP8 kernels need at least
  // CUDA 12.0 on SM90 systems (Hopper)
  // CUDA 12.4 on SM89 systems (Lovelace)

#if defined CUDA_VERSION
  if (cuda_device_capability >= 90) {
    return CUDA_VERSION >= 12000;
  } else if (cuda_device_capability >= 89) {
    return CUDA_VERSION >= 12040;
  }
#endif

  return false;
}

bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
  // CUTLASS block-quantized FP8 kernels need at least CUDA 12.0
  // and at least SM90 (Hopper)

#if defined CUDA_VERSION
  if (cuda_device_capability >= 100) {
    return CUDA_VERSION >= 12080;
  } else if (cuda_device_capability >= 90) {
    return CUDA_VERSION >= 12000;
  }
#endif

  return false;
}

bool cutlass_group_gemm_supported(int64_t cuda_device_capability) {
  // CUTLASS grouped FP8 kernels need at least CUDA 12.3 and SM90 (Hopper)
  // or CUDA 12.8 and SM100 (Blackwell)

#if defined CUDA_VERSION
  if (cuda_device_capability >= 100) {
    return CUDA_VERSION >= 12080;
  }
  if (cuda_device_capability >= 90) {
    return CUDA_VERSION >= 12030;
  }
#endif

  return false;
}
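
Illustrative only, not part of the diff: the three gates above combine the compile-time CUDA_VERSION with the runtime SM version, so the same binary answers differently depending on how it was built. For example, under the assumption of a build with CUDA 12.4 running on an SM89 (Ada) GPU, the checks above would evaluate as follows.

// Hypothetical results for a CUDA 12.4 build on SM89 (Ada), per the checks above:
// cutlass_scaled_mm_supports_fp8(89)       -> true   (SM89 needs CUDA >= 12.4)
// cutlass_scaled_mm_supports_block_fp8(89) -> false  (block FP8 needs SM90+)
// cutlass_group_gemm_supported(89)         -> false  (grouped GEMM needs SM90+)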

void cutlass_scaled_mm(torch::stable::Tensor& c, torch::stable::Tensor const& a,
                       torch::stable::Tensor const& b,
                       torch::stable::Tensor const& a_scales,
                       torch::stable::Tensor const& b_scales,
                       std::optional<torch::stable::Tensor> const& bias) {
  // Checks for conformality
  STD_TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
  STD_TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
                  b.size(1) == c.size(1));

  // Check for strides and alignment
  STD_TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major
  STD_TORCH_CHECK(b.stride(0) == 1);                      // Column-major
  STD_TORCH_CHECK(c.stride(0) % 16 == 0 &&
                  b.stride(1) % 16 == 0);  // 16 Byte Alignment

  if (bias) {
    STD_TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
                    bias->dim() == 1);
  }

  const torch::stable::accelerator::DeviceGuard device_guard(
      a.get_device_index());
  int32_t version_num = get_sm_version_num();

#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
  if (version_num >= 120) {
    cutlass_scaled_mm_sm120(c, a, b, a_scales, b_scales, bias);
    return;
  }
#endif

#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
  if (version_num >= 100 && version_num < 120) {
    cutlass_scaled_mm_sm100(c, a, b, a_scales, b_scales, bias);
    return;
  }
#endif

  // Guard against compilation issues for sm90 kernels
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
  if (version_num >= 90 && version_num < 100) {
    // Hopper
    cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
    return;
  }
#endif

#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
  if (version_num == 89) {
    // Ada Lovelace
    cutlass_scaled_mm_sm89(c, a, b, a_scales, b_scales, bias);
    return;
  }

  if (version_num >= 80) {
    // Ampere
    cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias);
    return;
  }

  if (version_num >= 75) {
    // Turing
    cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias);
    return;
  }
#endif

  STD_TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "No compiled cutlass_scaled_mm for a compute capability less than "
      "CUDA device capability: ",
      version_num);
}
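
A minimal caller-side sketch, not part of the diff: it only restates the layout contract that cutlass_scaled_mm enforces above (row-major a and c, column-major b, fp32 scales, optional bias). The helper name and the provenance of the tensors are hypothetical.

// Sketch: forwards pre-validated tensors to the dispatcher defined above.
void scaled_mm_example(torch::stable::Tensor& c,        // [M, N], row-major
                       torch::stable::Tensor const& a,  // [M, K], row-major, int8/fp8
                       torch::stable::Tensor const& b,  // [K, N], column-major
                       torch::stable::Tensor const& a_scales,   // fp32, 1 or M values
                       torch::stable::Tensor const& b_scales) { // fp32, 1 or N values
  cutlass_scaled_mm(c, a, b, a_scales, b_scales, /*bias=*/std::nullopt);
}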

void cutlass_moe_mm(torch::stable::Tensor& out_tensors,
                    torch::stable::Tensor const& a_tensors,
                    torch::stable::Tensor const& b_tensors,
                    torch::stable::Tensor const& a_scales,
                    torch::stable::Tensor const& b_scales,
                    torch::stable::Tensor const& expert_offsets,
                    torch::stable::Tensor const& problem_sizes,
                    torch::stable::Tensor const& a_strides,
                    torch::stable::Tensor const& b_strides,
                    torch::stable::Tensor const& c_strides, bool per_act_token,
                    bool per_out_ch) {
  int32_t version_num = get_sm_version_num();
#if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100
  if (version_num >= 100 && version_num < 110) {
    cutlass_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
                         expert_offsets, problem_sizes, a_strides, b_strides,
                         c_strides, per_act_token, per_out_ch);
    return;
  }
#endif
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
  if (version_num >= 90 && version_num < 100) {
    cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
                        expert_offsets, problem_sizes, a_strides, b_strides,
                        c_strides, per_act_token, per_out_ch);
    return;
  }
#endif
  STD_TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "No compiled cutlass_scaled_mm for CUDA device capability: ", version_num,
      ". Required capability: 90 or 100");
}

void get_cutlass_moe_mm_data(
    const torch::stable::Tensor& topk_ids,
    torch::stable::Tensor& expert_offsets,
    torch::stable::Tensor& problem_sizes1,
    torch::stable::Tensor& problem_sizes2,
    torch::stable::Tensor& input_permutation,
    torch::stable::Tensor& output_permutation, const int64_t num_experts,
    const int64_t n, const int64_t k,
    const std::optional<torch::stable::Tensor>& blockscale_offsets,
    const bool is_gated) {
  // This function currently gets compiled only if we have a valid cutlass moe
  // mm to run it for.
  int32_t version_num = get_sm_version_num();
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) ||   \
    (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
    (defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
  get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1,
                                 problem_sizes2, input_permutation,
                                 output_permutation, num_experts, n, k,
                                 blockscale_offsets, is_gated);
  return;
#endif
  STD_TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "No compiled get_cutlass_moe_mm_data: no cutlass_scaled_mm kernel for "
      "CUDA device capability: ",
      version_num, ". Required capability: 90, 100, or 120");
}

void get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
    const torch::stable::Tensor& expert_first_token_offset,
    torch::stable::Tensor& problem_sizes1,
    torch::stable::Tensor& problem_sizes2, const int64_t n, const int64_t k,
    const bool swap_ab) {
  int32_t version_num = get_sm_version_num();
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) ||   \
    (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
    (defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
  get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
      expert_first_token_offset, problem_sizes1, problem_sizes2, n, k, swap_ab);
  return;
#endif
  STD_TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "No compiled get_cutlass_moe_mm_problem_sizes_from_expert_offsets: "
      "no cutlass_scaled_mm kernel for CUDA device capability: ",
      version_num, ". Required capability: 90, 100, or 120");
}

void get_cutlass_batched_moe_mm_data(
    torch::stable::Tensor& expert_offsets,
    torch::stable::Tensor& problem_sizes1,
    torch::stable::Tensor& problem_sizes2,
    const torch::stable::Tensor& expert_num_tokens,
    const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
    const int64_t k) {
  // This function currently gets compiled only if we have a valid cutlass moe
  // mm to run it for.
  int32_t version_num = get_sm_version_num();
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) ||   \
    (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
    (defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
  get_cutlass_batched_moe_mm_data_caller(expert_offsets, problem_sizes1,
                                         problem_sizes2, expert_num_tokens,
                                         num_local_experts, padded_m, n, k);
  return;
#endif
  STD_TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "No compiled get_cutlass_batched_moe_mm_data: no "
      "cutlass_scaled_mm kernel "
      "for CUDA device capability: ",
      version_num, ". Required capability: 90, 100, or 120");
}

void cutlass_scaled_mm_azp(torch::stable::Tensor& c,
                           torch::stable::Tensor const& a,
                           torch::stable::Tensor const& b,
                           torch::stable::Tensor const& a_scales,
                           torch::stable::Tensor const& b_scales,
                           torch::stable::Tensor const& azp_adj,
                           std::optional<torch::stable::Tensor> const& azp,
                           std::optional<torch::stable::Tensor> const& bias) {
  // Checks for conformality
  STD_TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
  STD_TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
                  b.size(1) == c.size(1));
  STD_TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
  STD_TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));

  // Check for strides and alignment
  STD_TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major
  STD_TORCH_CHECK(b.stride(0) == 1);                      // Column-major
  STD_TORCH_CHECK(c.stride(0) % 16 == 0 &&
                  b.stride(1) % 16 == 0);  // 16 Byte Alignment
  STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());

  // bias, azp, azp_adj are all 1d
  // bias and azp_adj have n elements, azp has m elements
  if (bias) {
    STD_TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous());
  }
  if (azp) {
    STD_TORCH_CHECK(azp->numel() == a.size(0) && azp->is_contiguous());
  }
  STD_TORCH_CHECK(azp_adj.numel() == b.size(1) && azp_adj.is_contiguous());

  // azp & bias types
  STD_TORCH_CHECK(azp_adj.scalar_type() == torch::headeronly::ScalarType::Int);
  STD_TORCH_CHECK(!azp ||
                  azp->scalar_type() == torch::headeronly::ScalarType::Int);
  STD_TORCH_CHECK(!bias || bias->scalar_type() == c.scalar_type(),
                  "currently bias dtype must match output dtype ",
                  c.scalar_type());

  const torch::stable::accelerator::DeviceGuard device_guard(
      a.get_device_index());

  int32_t version_num = get_sm_version_num();

#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
  if (version_num >= 90) {
    cutlass_scaled_mm_azp_sm90(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
    return;
  }
#endif

#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
  if (version_num == 89) {
    // Ada Lovelace
    cutlass_scaled_mm_azp_sm89(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
    return;
  }

  if (version_num >= 80) {
    // Ampere
    cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
    return;
  }

  // Turing
  STD_TORCH_CHECK(version_num >= 75);
  cutlass_scaled_mm_azp_sm75(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
  return;
#endif

  STD_TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "No compiled cutlass_scaled_mm_azp for a compute capability less than "
      "CUDA device capability: ",
      version_num);
}
@@ -31,6 +31,174 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
      "per_token_group_quant_int8(Tensor input, Tensor! output_q, Tensor! "
      "output_s, int group_size, float eps, float int8_min, float int8_max) -> "
      "()");

  // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
  // quantization, as well as bias
  ops.def(
      "cutlass_scaled_mm(Tensor! out, Tensor a,"
      " Tensor b, Tensor a_scales,"
      " Tensor b_scales, Tensor? bias) -> ()");

  // CUTLASS w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
  // quantization.
  ops.def(
      "cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
      " Tensor b, Tensor a_scales,"
      " Tensor b_scales, Tensor azp_adj,"
      " Tensor? azp, Tensor? bias) -> ()");

  // Check if cutlass scaled_mm is supported for CUDA devices of the given
  // capability
  ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool");

  // Check if cutlass grouped gemm is supported for CUDA devices of the given
  // capability
  ops.def("cutlass_group_gemm_supported(int cuda_device_capability) -> bool");

  // CUTLASS w8a8 grouped GEMM
  ops.def(
      "cutlass_moe_mm(Tensor! out_tensors, Tensor a_tensors, Tensor b_tensors, "
      " Tensor a_scales, Tensor b_scales, Tensor expert_offsets, "
      " Tensor problem_sizes, Tensor a_strides, "
      " Tensor b_strides, Tensor c_strides, bool per_act_token, "
      " bool per_out_ch) -> ()");

  // A function that computes data required to run fused MoE with w8a8 grouped
  // GEMM. It takes topk_ids as an input, and computes expert_offsets
  // (token start indices of each expert). In addition to this, it computes
  // problem sizes for each expert's multiplication used by the two mms called
  // from fused MoE operation, and arrays with permutations required to shuffle
  // and de-shuffle the input/output of the fused operation.
  ops.def(
      "get_cutlass_moe_mm_data(Tensor topk_ids, Tensor! expert_offsets, "
      " Tensor! problem_sizes1, Tensor! problem_sizes2, "
      " Tensor! input_permutation, "
      " Tensor! output_permutation, int num_experts, "
      " int n, int k, Tensor? blockscale_offsets, "
      " bool is_gated) -> ()");

  // compute per-expert problem sizes from expert_first_token_offset
  // produced by vLLM's moe_permute kernel
  ops.def(
      "get_cutlass_moe_mm_problem_sizes_from_expert_offsets("
      " Tensor expert_first_token_offset, "
      " Tensor! problem_sizes1, "
      " Tensor! problem_sizes2, "
      " int n, int k, bool swap_ab) -> ()");

  // A function that computes data required to run fused MoE with w8a8 grouped
  // GEMM in batched expert format. It takes expert_num_tokens
  // as an input, and computes expert_offsets (token start indices of each
  // expert). In addition to this, it computes problem sizes for each expert's
  // multiplication used by the two mms called from fused MoE operation.
  ops.def(
      "get_cutlass_batched_moe_mm_data(Tensor! expert_offsets, "
      " Tensor! problem_sizes1, "
      " Tensor! problem_sizes2, "
      " Tensor expert_num_tokens, "
      " int num_local_experts, int padded_m, "
      " int n, int k) -> ()");

  // Check if cutlass scaled_mm supports block quantization (used by DeepSeekV3)
  ops.def(
      "cutlass_scaled_mm_supports_block_fp8(int cuda_device_capability) -> "
      "bool");

  // CUTLASS nvfp4 block scaled GEMM
  ops.def(
      "cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
      " Tensor block_scale_a, Tensor block_scale_b,"
      " Tensor alpha) -> ()");

  // cutlass nvfp4 block scaled group GEMM
  ops.def(
      "cutlass_fp4_group_mm(Tensor! out, Tensor a, Tensor b,"
      " Tensor a_blockscale, Tensor b_blockscales, Tensor alphas,"
      " Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()");

  // Compute NVFP4 block quantized tensor.
  ops.def(
      "scaled_fp4_quant(Tensor input,"
      " Tensor input_scale, bool "
      "is_sf_swizzled_layout) -> (Tensor, Tensor)");

  // Out variant
  // TODO: Add out_variant tag once PyTorch supports it (added in 2.11)
  // This registration is now migrated to stable ABI
  // at::Tag::out_variant is not available in the stable ABI (enum_tag.h is not
  // yet in torch/headeronly), the tag should be applied from Python
  // via torch.library.Library.define(..., tags=(torch.Tag.out_variant,))
  // with the .impl remaining in C++.
  // See pytorch/pytorch#176117.
  ops.def(
      "scaled_fp4_quant.out(Tensor input,"
      " Tensor input_scale, bool "
      "is_sf_swizzled_layout, *, Tensor(a!) output, Tensor(b!) output_scale) "
      "-> ()");

  // Compute NVFP4 experts quantization.
  ops.def(
      "scaled_fp4_experts_quant(Tensor! output, Tensor! output_scale,"
      "Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
      "Tensor output_scale_offset_by_experts) -> ()");

  // Fused SiLU+Mul+NVFP4 experts quantization.
  ops.def(
      "silu_and_mul_scaled_fp4_experts_quant(Tensor! output, Tensor! "
      "output_scale,"
      "Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
      "Tensor output_scale_offset_by_experts) -> ()");

  // Fused SiLU+Mul+NVFP4 quantization.
  ops.def(
      "silu_and_mul_nvfp4_quant(Tensor! result, Tensor! result_block_scale, "
      "Tensor input, Tensor input_global_scale) -> ()");

  // Check if cutlass_scaled_mm_fp4 is supported for CUDA devices
  // of the given capability
  ops.def("cutlass_scaled_mm_supports_fp4(int cuda_device_capability) -> bool");

  // CUTLASS w4a8 GEMM
  ops.def(
      "cutlass_w4a8_mm("
      " Tensor A,"
      " Tensor B,"
      " Tensor group_scales,"
      " int group_size,"
      " Tensor channel_scales,"
      " Tensor token_scales,"
      " ScalarType? out_type,"
      " str? maybe_schedule"
      ") -> Tensor");

  // pack scales
  ops.def("cutlass_pack_scale_fp8(Tensor scales) -> Tensor");

  // encode and reorder weight matrix
  ops.def("cutlass_encode_and_reorder_int4b(Tensor B) -> Tensor");

  // CUTLASS w4a8 grouped GEMM
  ops.def(
      "cutlass_w4a8_moe_mm("
      " Tensor! out_tensors,"
      " Tensor a_tensors,"
      " Tensor b_tensors,"
      " Tensor a_scales,"
      " Tensor b_scales,"
      " Tensor b_group_scales,"
      " int b_group_size,"
      " Tensor expert_offsets,"
      " Tensor problem_sizes,"
      " Tensor a_strides,"
      " Tensor b_strides,"
      " Tensor c_strides,"
      " Tensor group_scale_strides,"
      " str? maybe_schedule"
      ") -> ()");

  ops.def(
      "cutlass_encode_and_reorder_int4b_grouped(Tensor b_tensors) -> (Tensor, "
      "Tensor)");
#endif
}

@@ -46,6 +214,45 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
           TORCH_BOX(&per_token_group_quant_8bit_packed));
  ops.impl("per_token_group_quant_int8",
           TORCH_BOX(&per_token_group_quant_int8));

  // CUTLASS scaled_mm ops
  ops.impl("cutlass_scaled_mm", TORCH_BOX(&cutlass_scaled_mm));
  ops.impl("cutlass_scaled_mm_azp", TORCH_BOX(&cutlass_scaled_mm_azp));
  ops.impl("cutlass_moe_mm", TORCH_BOX(&cutlass_moe_mm));
  ops.impl("get_cutlass_moe_mm_data", TORCH_BOX(&get_cutlass_moe_mm_data));
  ops.impl("get_cutlass_moe_mm_problem_sizes_from_expert_offsets",
           TORCH_BOX(&get_cutlass_moe_mm_problem_sizes_from_expert_offsets));
  ops.impl("get_cutlass_batched_moe_mm_data",
           TORCH_BOX(&get_cutlass_batched_moe_mm_data));

  // FP4/NVFP4 ops
  ops.impl("cutlass_scaled_fp4_mm", TORCH_BOX(&cutlass_scaled_fp4_mm));
  ops.impl("scaled_fp4_quant", TORCH_BOX(&scaled_fp4_quant_func));
  ops.impl("scaled_fp4_quant.out", TORCH_BOX(&scaled_fp4_quant_out));
  ops.impl("scaled_fp4_experts_quant", TORCH_BOX(&scaled_fp4_experts_quant));
  ops.impl("silu_and_mul_scaled_fp4_experts_quant",
           TORCH_BOX(&silu_and_mul_scaled_fp4_experts_quant));
  ops.impl("silu_and_mul_nvfp4_quant", TORCH_BOX(&silu_and_mul_nvfp4_quant));

  // W4A8 ops: impl registrations are in the source files
  // (w4a8_mm_entry.cu and w4a8_grouped_mm_entry.cu)
#endif
}

// These capability-check functions take only primitive args (no tensors), so
// there is no device to dispatch on. CompositeExplicitAutograd makes them
// available for all backends. This is the stable ABI equivalent of calling
// ops.impl("op_name", &func) without a dispatch key in the non-stable API.
STABLE_TORCH_LIBRARY_IMPL(_C, CompositeExplicitAutograd, ops) {
#ifndef USE_ROCM
  ops.impl("cutlass_scaled_mm_supports_fp8",
           TORCH_BOX(&cutlass_scaled_mm_supports_fp8));
  ops.impl("cutlass_group_gemm_supported",
           TORCH_BOX(&cutlass_group_gemm_supported));
  ops.impl("cutlass_scaled_mm_supports_block_fp8",
           TORCH_BOX(&cutlass_scaled_mm_supports_block_fp8));
  ops.impl("cutlass_scaled_mm_supports_fp4",
           TORCH_BOX(&cutlass_scaled_mm_supports_fp4));
#endif
}
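
For reference, a sketch of the registration pattern used throughout this file, shown here with a hypothetical op that is not part of this PR: the schema string goes to ops.def inside STABLE_TORCH_LIBRARY_FRAGMENT, and the boxed C++ function goes to ops.impl under the appropriate dispatch key.

// Hypothetical example only, mirroring the pattern above.
bool my_kernel_supported(int64_t cuda_device_capability) {
  return cuda_device_capability >= 90;  // assumption, for illustration
}

STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
  ops.def("my_kernel_supported(int cuda_device_capability) -> bool");
}

STABLE_TORCH_LIBRARY_IMPL(_C, CompositeExplicitAutograd, ops) {
  ops.impl("my_kernel_supported", TORCH_BOX(&my_kernel_supported));
}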

@@ -1,10 +1,17 @@
#pragma once

#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/stable/accelerator.h>
#include <torch/csrc/stable/ops.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/util/shim_utils.h>

#include <cuda_runtime.h>

// Stable ABI equivalent of TORCH_CHECK_NOT_IMPLEMENTED.
#define STD_TORCH_CHECK_NOT_IMPLEMENTED(cond, ...) \
  STD_TORCH_CHECK(cond, "NotImplementedError: ", __VA_ARGS__)

// Utility to get the current CUDA stream for a given device using stable APIs.
// Returns a cudaStream_t for use in kernel launches.
inline cudaStream_t get_current_cuda_stream(int32_t device_index = -1) {

@@ -21,7 +21,7 @@ struct SSMParamsBase {
  int dim_ngroups_ratio;
  bool is_variable_B;
  bool is_variable_C;
  int64_t pad_slot_id;
  int64_t null_block_id;

  bool delta_softplus;
  bool cache_enabled;

@@ -118,9 +118,17 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {

  const int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr
      : reinterpret_cast<int *>(params.cache_indices_ptr);
  const int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];
  // cache_index == params.pad_slot_id is defined as padding, so we exit early
  if (cache_index == params.pad_slot_id){
  int cache_index;
  if (cache_indices == nullptr) {
    cache_index = batch_id;
  } else if (params.cache_enabled) {
    const int* initial_state_idx = reinterpret_cast<const int*>(params.initial_state_idx_ptr);
    cache_index = cache_indices[batch_id * params.cache_indices_stride + initial_state_idx[batch_id]];
  } else {
    cache_index = cache_indices[batch_id];
  }
  // Skip batch entries whose cache index maps to the null block (padding).
  if (cache_indices != nullptr && cache_index == params.null_block_id){
    return;
  }
  input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + sequence_start_index * params.u_batch_stride
@@ -527,7 +535,7 @@ void set_ssm_params_fwd(SSMParamsBase &params,
                        const std::optional<at::Tensor>& cache_indices,
                        const std::optional<at::Tensor>& has_initial_state,
                        bool varlen,
                        int64_t pad_slot_id,
                        int64_t null_block_id,
                        int64_t block_size,
                        const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
                        const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
@@ -544,7 +552,7 @@ void set_ssm_params_fwd(SSMParamsBase &params,
  params.dstate = dstate;
  params.n_groups = n_groups;
  params.dim_ngroups_ratio = dim / n_groups;
  params.pad_slot_id = pad_slot_id;
  params.null_block_id = null_block_id;

  params.delta_softplus = delta_softplus;

@@ -658,7 +666,7 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
                        const torch::Tensor &ssm_states,
                        // used to identify padding entries if cache_indices provided
                        // in case of padding, the kernel will return early
                        int64_t pad_slot_id,
                        int64_t null_block_id,
                        int64_t block_size,
                        const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
                        const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
@@ -805,7 +813,7 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
                    cache_indices,
                    has_initial_state,
                    varlen,
                    pad_slot_id,
                    null_block_id,
                    block_size,
                    block_idx_first_scheduled_token,
                    block_idx_last_scheduled_token,
|
||||
|
||||
@@ -1,144 +0,0 @@
/*
 * Adapted from
 * https://github.com/NVIDIA/TensorRT-LLM/blob/v1.3.0rc7/cpp/tensorrt_llm/kernels/tinygemm2/tinygemm2_cuda.cu
 * Copyright (c) 2025, The vLLM team.
 * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
 * All rights reserved. SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/all.h>
#include "gpt_oss_router_gemm.cuh"

void launch_gpt_oss_router_gemm(__nv_bfloat16* gA, __nv_bfloat16* gB,
                                __nv_bfloat16* gC, __nv_bfloat16* bias,
                                int batch_size, int output_features,
                                int input_features, cudaStream_t stream) {
  static int const WARP_TILE_M = 16;
  static int const TILE_M = WARP_TILE_M;
  static int const TILE_N = 8;
  static int const TILE_K = 64;
  static int const STAGES = 16;
  static int const STAGE_UNROLL = 4;
  static bool const PROFILE = false;

  CUtensorMap weight_map{};
  CUtensorMap activation_map{};

  constexpr uint32_t rank = 2;
  uint64_t size[rank] = {(uint64_t)input_features, (uint64_t)output_features};
  uint64_t stride[rank - 1] = {input_features * sizeof(__nv_bfloat16)};
  uint32_t box_size[rank] = {TILE_K, TILE_M};
  uint32_t elem_stride[rank] = {1, 1};

  CUresult res = cuTensorMapEncodeTiled(
      &weight_map, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, rank,
      gB, size, stride, box_size, elem_stride,
      CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
      CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B,
      CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE,
      CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
  TORCH_CHECK(res == CUDA_SUCCESS,
              "cuTensorMapEncodeTiled failed for weight_map, error code=",
              static_cast<int>(res));

  size[1] = batch_size;
  box_size[1] = TILE_N;

  res = cuTensorMapEncodeTiled(
      &activation_map, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
      rank, gA, size, stride, box_size, elem_stride,
      CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
      CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B,
      CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE,
      CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
  TORCH_CHECK(res == CUDA_SUCCESS,
              "cuTensorMapEncodeTiled failed for activation_map, error code=",
              static_cast<int>(res));

  int smem_size = STAGES * STAGE_UNROLL *
                  (TILE_M * TILE_K * sizeof(__nv_bfloat16) +
                   TILE_N * TILE_K * sizeof(__nv_bfloat16));

  gpuErrChk(cudaFuncSetAttribute(
      gpt_oss_router_gemm_kernel<WARP_TILE_M, TILE_M, TILE_N, TILE_K, STAGES,
                                 STAGE_UNROLL, PROFILE>,
      cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));

  int tiles_m = (output_features + TILE_M - 1) / TILE_M;
  int tiles_n = (batch_size + TILE_N - 1) / TILE_N;

  dim3 grid(tiles_m, tiles_n);
  dim3 block(384);

  cudaLaunchConfig_t config;
  cudaLaunchAttribute attrs[1];
  config.gridDim = grid;
  config.blockDim = block;
  config.dynamicSmemBytes = smem_size;
  config.stream = stream;
  config.attrs = attrs;
  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attrs[0].val.programmaticStreamSerializationAllowed = 1;
  config.numAttrs = 1;

  cudaLaunchKernelEx(
      &config,
      &gpt_oss_router_gemm_kernel<WARP_TILE_M, TILE_M, TILE_N, TILE_K, STAGES,
                                  STAGE_UNROLL, PROFILE>,
      gC, gA, gB, bias, output_features, batch_size, input_features, weight_map,
      activation_map, nullptr);
}

void gpt_oss_router_gemm_cuda_forward(torch::Tensor& output,
                                      torch::Tensor input, torch::Tensor weight,
                                      torch::Tensor bias) {
  auto const batch_size = input.size(0);
  auto const input_dim = input.size(1);
  auto const output_dim = weight.size(0);

  auto stream = at::cuda::getCurrentCUDAStream();

  if (input.scalar_type() == at::ScalarType::BFloat16) {
    launch_gpt_oss_router_gemm((__nv_bfloat16*)input.data_ptr(),
                               (__nv_bfloat16*)weight.data_ptr(),
                               (__nv_bfloat16*)output.mutable_data_ptr(),
                               (__nv_bfloat16*)bias.data_ptr(), batch_size,
                               output_dim, input_dim, stream);
  } else {
    throw std::invalid_argument("Unsupported dtype, only supports bfloat16");
  }
}

void gpt_oss_router_gemm(torch::Tensor& output, torch::Tensor input,
                         torch::Tensor weight, torch::Tensor bias) {
  TORCH_CHECK(input.dim() == 2, "input must be 2D");
  TORCH_CHECK(weight.dim() == 2, "weight must be 2D");
  TORCH_CHECK(bias.dim() == 1, "bias must be 1D");
  TORCH_CHECK(input.sizes()[1] == weight.sizes()[1],
              "input.size(1) must match weight.size(1)");
  TORCH_CHECK(weight.sizes()[0] == bias.sizes()[0],
              "weight.size(0) must match bias.size(0)");
  TORCH_CHECK(input.scalar_type() == at::ScalarType::BFloat16,
              "input tensor must be bfloat16");
  TORCH_CHECK(weight.scalar_type() == at::ScalarType::BFloat16,
              "weight tensor must be bfloat16");
  TORCH_CHECK(bias.scalar_type() == at::ScalarType::BFloat16,
              "bias tensor must be bfloat16");
  gpt_oss_router_gemm_cuda_forward(output, input, weight, bias);
}
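
Two reference points for the file removed above. First, with the fixed tiling it used (TILE_M = 16, TILE_N = 8, TILE_K = 64, STAGES = 16, STAGE_UNROLL = 4), the dynamic shared-memory request works out to STAGES * STAGE_UNROLL * (TILE_M + TILE_N) * TILE_K * sizeof(__nv_bfloat16) = 64 * 24 * 64 * 2 = 196608 bytes (192 KiB), which is why the launcher raised cudaFuncAttributeMaxDynamicSharedMemorySize before launching; that is more per-block shared memory than pre-Hopper GPUs expose. Second, a minimal host-side sketch of calling the (now removed) entry point: the problem sizes and the [batch_size, output_features] output layout are assumptions inferred from the launcher, not spelled out in this diff.

#include <torch/torch.h>

// Declaration of the removed entry point (defined in the deleted .cu file above).
void gpt_oss_router_gemm(torch::Tensor& output, torch::Tensor input,
                         torch::Tensor weight, torch::Tensor bias);

int main() {
  // Illustrative sizes only; they are not taken from this diff.
  const int64_t batch_size = 32, input_features = 2880, output_features = 128;
  auto opts = torch::TensorOptions().dtype(torch::kBFloat16).device(torch::kCUDA);

  auto input  = torch::randn({batch_size, input_features}, opts);       // activations (gA)
  auto weight = torch::randn({output_features, input_features}, opts);  // row-major weights (gB)
  auto bias   = torch::randn({output_features}, opts);
  auto output = torch::empty({batch_size, output_features}, opts);      // assumed output layout

  // Checks dims/dtypes, then launches the bf16 TMA kernel on the current stream.
  gpt_oss_router_gemm(output, input, weight, bias);
  return 0;
}
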
Some files were not shown because too many files have changed in this diff.