Compare commits
206 Commits
v0.16.1rc0
...
v0.17.0rc0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
097eb544e9 | ||
|
|
7cdba98edf | ||
|
|
3c85cd9d74 | ||
|
|
edba15045a | ||
|
|
e379396167 | ||
|
|
6e9f21e8a2 | ||
|
|
c1d963403c | ||
|
|
77e6dcbbfa | ||
|
|
70c73df69e | ||
|
|
9a9d442464 | ||
|
|
f7da9cdffc | ||
|
|
f22ff2958c | ||
|
|
d15c3b90fc | ||
|
|
97286a20ed | ||
|
|
12b38c0f45 | ||
|
|
467886a0c4 | ||
|
|
a9b8b13e5c | ||
|
|
e7213003cb | ||
|
|
3a8eef5869 | ||
|
|
97995f6376 | ||
|
|
881a6b011b | ||
|
|
8e1fd5baf0 | ||
|
|
ae88468bcc | ||
|
|
e05cb3b93e | ||
|
|
28ef9ba399 | ||
|
|
fb7fdc49c4 | ||
|
|
ea463978bb | ||
|
|
440f0e7dc6 | ||
|
|
fd4a90f337 | ||
|
|
ad9d09e2b8 | ||
|
|
4beebfd146 | ||
|
|
b8401cde0e | ||
|
|
5dfc5abe94 | ||
|
|
8fa68a8ce4 | ||
|
|
35a6f0bfe2 | ||
|
|
3a6cbf16e2 | ||
|
|
f44d1ddc8c | ||
|
|
48a54c1e0d | ||
|
|
8b9e8b7454 | ||
|
|
c21d0039ec | ||
|
|
7d8bbe6f42 | ||
|
|
25e02647c2 | ||
|
|
a0a5178ab4 | ||
|
|
8ea8ba275e | ||
|
|
4f85bae9d6 | ||
|
|
0a7165fd71 | ||
|
|
6521ccf286 | ||
|
|
8ebd872f50 | ||
|
|
168ee03e1c | ||
|
|
9dd656f0ea | ||
|
|
c8b678e53e | ||
|
|
18c29c746b | ||
|
|
96fc09503a | ||
|
|
1b82b433fc | ||
|
|
9319044ee9 | ||
|
|
c42dc402c1 | ||
|
|
fa6a6be519 | ||
|
|
cad21918e3 | ||
|
|
53700bf49b | ||
|
|
a13d8c03c9 | ||
|
|
9433acb8df | ||
|
|
d1a6e96d9e | ||
|
|
2a9e3347e9 | ||
|
|
cc0d565f40 | ||
|
|
358e4d5ba7 | ||
|
|
792a74b973 | ||
|
|
4034c3d32e | ||
|
|
7560d674c9 | ||
|
|
d9c7730877 | ||
|
|
ada4f4fadd | ||
|
|
7e9149d9a9 | ||
|
|
87c98b0236 | ||
|
|
de7dd634b9 | ||
|
|
9a87b0578f | ||
|
|
510bc9e1df | ||
|
|
cbd361fd46 | ||
|
|
c212202d93 | ||
|
|
ec27b36b4b | ||
|
|
3fd1d4ec2c | ||
|
|
cb21972a97 | ||
|
|
c34963f138 | ||
|
|
f26650d649 | ||
|
|
92f5d0f070 | ||
|
|
a60985b07e | ||
|
|
8b5014d3dd | ||
|
|
57a96e26c9 | ||
|
|
e82fbeec7b | ||
|
|
6290470843 | ||
|
|
72f4d16262 | ||
|
|
5a435507d8 | ||
|
|
59d7af9c6c | ||
|
|
bbf81f9a92 | ||
|
|
da543d1abe | ||
|
|
87d319c52f | ||
|
|
a9ec392c86 | ||
|
|
afd089f231 | ||
|
|
3ecd0bf9fc | ||
|
|
e3eb146f7a | ||
|
|
95a395dbec | ||
|
|
e94b263bd6 | ||
|
|
e113a30113 | ||
|
|
1dafb29f91 | ||
|
|
49b9ae32e9 | ||
|
|
63d7972f13 | ||
|
|
c68e69f144 | ||
|
|
7e08c22b8c | ||
|
|
8e75d88554 | ||
|
|
0892d1ab1f | ||
|
|
7600642eae | ||
|
|
1e69c04887 | ||
|
|
4292e3b807 | ||
|
|
24d6ea8afd | ||
|
|
57c86c0741 | ||
|
|
06254d4cbb | ||
|
|
f5d1281c9d | ||
|
|
94029ffaf0 | ||
|
|
88e8525f2e | ||
|
|
b2d8b422b2 | ||
|
|
1d5ab5d603 | ||
|
|
7b346ba8ed | ||
|
|
dea268336f | ||
|
|
90805ff464 | ||
|
|
2562e0271e | ||
|
|
fd68cd132b | ||
|
|
0edf101d2b | ||
|
|
d5b6f3ba36 | ||
|
|
1a014a0a93 | ||
|
|
86ac7bcf84 | ||
|
|
405f28d38d | ||
|
|
5323672bc2 | ||
|
|
a201ad72d8 | ||
|
|
e3691988d0 | ||
|
|
9fa6c68fa6 | ||
|
|
2ce6f3cf67 | ||
|
|
1f3dbd95fd | ||
|
|
1d532f9d8f | ||
|
|
234a65b781 | ||
|
|
2decec9856 | ||
|
|
29b35477b0 | ||
|
|
b1d9f5372d | ||
|
|
fd6de37fca | ||
|
|
c8aca0c9e1 | ||
|
|
b602e4f299 | ||
|
|
157722da75 | ||
|
|
1d897ff04f | ||
|
|
905d76b51d | ||
|
|
9098ce690c | ||
|
|
876312f0b5 | ||
|
|
5de98abc12 | ||
|
|
9251ed5c4f | ||
|
|
e8249378e4 | ||
|
|
6d4f9d3ad5 | ||
|
|
fbe3f0120a | ||
|
|
66c1751d13 | ||
|
|
6467b635b6 | ||
|
|
9c3fe9936b | ||
|
|
b66a74649e | ||
|
|
07bdabef03 | ||
|
|
a572baff5e | ||
|
|
516cf26698 | ||
|
|
487e5c51f7 | ||
|
|
1a8c71674e | ||
|
|
062b789632 | ||
|
|
a532c83849 | ||
|
|
1e5ad9b74f | ||
|
|
cabdaa7619 | ||
|
|
06be53563b | ||
|
|
c29ee9c326 | ||
|
|
d43048ce05 | ||
|
|
4fec53cfcb | ||
|
|
38c498b8e3 | ||
|
|
56a6371706 | ||
|
|
6283021142 | ||
|
|
01923eec70 | ||
|
|
31fb6f43da | ||
|
|
eb19955c37 | ||
|
|
0f2f24c8b2 | ||
|
|
d0105b84f0 | ||
|
|
832a780f3a | ||
|
|
98217b09f9 | ||
|
|
967572dd5f | ||
|
|
3d66502e1b | ||
|
|
c66aa48e99 | ||
|
|
b6d5a17298 | ||
|
|
5e58bdc711 | ||
|
|
a1f53addb1 | ||
|
|
05970c772c | ||
|
|
d940607629 | ||
|
|
99c7892c5b | ||
|
|
ec8f943db1 | ||
|
|
f2ad952f40 | ||
|
|
9e2cabdf9c | ||
|
|
ec8ab9d254 | ||
|
|
05972ea7e5 | ||
|
|
111d869069 | ||
|
|
7fea7250a4 | ||
|
|
845ee348ef | ||
|
|
ec13e549d3 | ||
|
|
c6ca51598a | ||
|
|
c0615a296d | ||
|
|
01914445b0 | ||
|
|
5281713e11 | ||
|
|
32693db8ce | ||
|
|
e03ddcfbd4 | ||
|
|
02acd16861 | ||
|
|
ab87f85231 |
@@ -51,5 +51,56 @@
|
|||||||
"max-model-len": 256,
|
"max-model-len": 256,
|
||||||
"async-scheduling": ""
|
"async-scheduling": ""
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"test_name": "latency_deepseek_r1",
|
||||||
|
"environment_variables": {
|
||||||
|
"PT_HPU_LAZY_MODE": 1,
|
||||||
|
"PT_HPU_ENABLE_LAZY_COLLECTIVES": 1,
|
||||||
|
"VLLM_CONTIGUOUS_PA": 1,
|
||||||
|
"VLLM_DEFRAG": 1
|
||||||
|
},
|
||||||
|
"parameters": {
|
||||||
|
"model": "deepseek-ai/DeepSeek-R1",
|
||||||
|
"tensor_parallel_size": 8,
|
||||||
|
"load_format": "dummy",
|
||||||
|
"max-model-len": 2048,
|
||||||
|
"dtype": "bfloat16"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"test_name": "latency_llama4_maverick_17b128e_instruct_fp8",
|
||||||
|
"environment_variables": {
|
||||||
|
"PT_HPU_LAZY_MODE": 1,
|
||||||
|
"PT_HPU_ENABLE_LAZY_COLLECTIVES": 1,
|
||||||
|
"VLLM_CONTIGUOUS_PA": 1,
|
||||||
|
"VLLM_DEFRAG": 1
|
||||||
|
},
|
||||||
|
"parameters": {
|
||||||
|
"model": "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
|
||||||
|
"tensor_parallel_size": 8,
|
||||||
|
"max-model-len": 512,
|
||||||
|
"max-num-seqs": 128,
|
||||||
|
"async-scheduling": "",
|
||||||
|
"gpu-memory-utilization": 0.95,
|
||||||
|
"enable_expert_parallel": ""
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"test_name": "latency_qwen3_8b",
|
||||||
|
"environment_variables": {
|
||||||
|
"PT_HPU_LAZY_MODE": 1,
|
||||||
|
"PT_HPU_ENABLE_LAZY_COLLECTIVES": 1,
|
||||||
|
"VLLM_CONTIGUOUS_PA": 1,
|
||||||
|
"VLLM_DEFRAG": 1
|
||||||
|
},
|
||||||
|
"parameters": {
|
||||||
|
"model": "Qwen/Qwen3-8B",
|
||||||
|
"tensor_parallel_size": 1,
|
||||||
|
"max-model-len": 2048,
|
||||||
|
"max-num-seqs": 128,
|
||||||
|
"dtype": "bfloat16",
|
||||||
|
"async-scheduling": ""
|
||||||
|
}
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -78,5 +78,84 @@
|
|||||||
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||||
"num_prompts": 200
|
"num_prompts": 200
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"test_name": "serving_deepseek_r1",
|
||||||
|
"qps_list": [1, 4, 16, "inf"],
|
||||||
|
"server_environment_variables": {
|
||||||
|
"PT_HPU_LAZY_MODE": 1,
|
||||||
|
"PT_HPU_ENABLE_LAZY_COLLECTIVES": 1,
|
||||||
|
"VLLM_CONTIGUOUS_PA": 1,
|
||||||
|
"VLLM_DEFRAG": 1
|
||||||
|
},
|
||||||
|
"server_parameters": {
|
||||||
|
"model": "deepseek-ai/DeepSeek-R1",
|
||||||
|
"tensor_parallel_size": 8,
|
||||||
|
"swap_space": 16,
|
||||||
|
"disable_log_stats": "",
|
||||||
|
"load_format": "dummy",
|
||||||
|
"max-model-len": 2048,
|
||||||
|
"max-num-seqs": 200,
|
||||||
|
"async-scheduling": "",
|
||||||
|
"dtype": "bfloat16"
|
||||||
|
},
|
||||||
|
"client_parameters": {
|
||||||
|
"model": "deepseek-ai/DeepSeek-R1",
|
||||||
|
"backend": "vllm",
|
||||||
|
"dataset_name": "sharegpt",
|
||||||
|
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||||
|
"num_prompts": 200
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"test_name": "serving_llama4_maverick_17b128e_instruct_fp8",
|
||||||
|
"qps_list": [1, 4, 16, "inf"],
|
||||||
|
"server_environment_variables": {
|
||||||
|
"PT_HPU_LAZY_MODE": 1,
|
||||||
|
"PT_HPU_ENABLE_LAZY_COLLECTIVES": 1,
|
||||||
|
"VLLM_CONTIGUOUS_PA": 1,
|
||||||
|
"VLLM_DEFRAG": 1
|
||||||
|
},
|
||||||
|
"server_parameters": {
|
||||||
|
"model": "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
|
||||||
|
"tensor_parallel_size": 8,
|
||||||
|
"disable_log_stats": "",
|
||||||
|
"max-model-len": 2048,
|
||||||
|
"max-num-seqs": 128,
|
||||||
|
"async-scheduling": "",
|
||||||
|
"enable_expert_parallel": "",
|
||||||
|
"max-num-batched-tokens": 4096
|
||||||
|
},
|
||||||
|
"client_parameters": {
|
||||||
|
"model": "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
|
||||||
|
"backend": "vllm",
|
||||||
|
"dataset_name": "sharegpt",
|
||||||
|
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||||
|
"num_prompts": 200
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"test_name": "serving_qwen3_8b",
|
||||||
|
"qps_list": [1, 4, 10, "inf"],
|
||||||
|
"server_environment_variables": {
|
||||||
|
"PT_HPU_LAZY_MODE": 1,
|
||||||
|
"PT_HPU_ENABLE_LAZY_COLLECTIVES": 1,
|
||||||
|
"VLLM_CONTIGUOUS_PA": 1,
|
||||||
|
"VLLM_DEFRAG": 1
|
||||||
|
},
|
||||||
|
"server_parameters": {
|
||||||
|
"model": "Qwen/Qwen-3-8B",
|
||||||
|
"tensor_parallel_size": 1,
|
||||||
|
"dtype": "bfloat16",
|
||||||
|
"disable_log_stats": "",
|
||||||
|
"async-scheduling": ""
|
||||||
|
},
|
||||||
|
"client_parameters": {
|
||||||
|
"model": "Qwen/Qwen-3-8B",
|
||||||
|
"backend": "vllm",
|
||||||
|
"dataset_name": "sharegpt",
|
||||||
|
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||||
|
"num_prompts": 200
|
||||||
|
}
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -57,5 +57,67 @@
|
|||||||
"max-num-seqs": 512,
|
"max-num-seqs": 512,
|
||||||
"async-scheduling": ""
|
"async-scheduling": ""
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"test_name": "throughput_deepseek_r1",
|
||||||
|
"environment_variables": {
|
||||||
|
"PT_HPU_LAZY_MODE": 1,
|
||||||
|
"PT_HPU_ENABLE_LAZY_COLLECTIVES": 1,
|
||||||
|
"VLLM_CONTIGUOUS_PA": 1,
|
||||||
|
"VLLM_DEFRAG": 1
|
||||||
|
},
|
||||||
|
"parameters": {
|
||||||
|
"model": "deepseek-ai/DeepSeek-R1",
|
||||||
|
"tensor_parallel_size": 8,
|
||||||
|
"load_format": "dummy",
|
||||||
|
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||||
|
"dataset_name": "sharegpt",
|
||||||
|
"num_prompts": 1000,
|
||||||
|
"backend": "vllm",
|
||||||
|
"max-model-len": 2048,
|
||||||
|
"max-num-seqs": 384,
|
||||||
|
"async-scheduling": ""
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"test_name": "throughput_llama4_maverick_17b128e_instruct_fp8",
|
||||||
|
"environment_variables": {
|
||||||
|
"PT_HPU_LAZY_MODE": 1,
|
||||||
|
"PT_HPU_ENABLE_LAZY_COLLECTIVES": 1,
|
||||||
|
"VLLM_CONTIGUOUS_PA": 1,
|
||||||
|
"VLLM_DEFRAG": 1
|
||||||
|
},
|
||||||
|
"parameters": {
|
||||||
|
"model": "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
|
||||||
|
"tensor_parallel_size": 8,
|
||||||
|
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||||
|
"dataset_name": "sharegpt",
|
||||||
|
"num_prompts": 1000,
|
||||||
|
"backend": "vllm",
|
||||||
|
"max-model-len": 2048,
|
||||||
|
"max-num-seqs": 512,
|
||||||
|
"async-scheduling": "",
|
||||||
|
"enable_expert_parallel": ""
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"test_name": "throughput_qwen3_8b",
|
||||||
|
"environment_variables": {
|
||||||
|
"PT_HPU_LAZY_MODE": 1,
|
||||||
|
"PT_HPU_ENABLE_LAZY_COLLECTIVES": 1,
|
||||||
|
"VLLM_CONTIGUOUS_PA": 1,
|
||||||
|
"VLLM_DEFRAG": 1
|
||||||
|
},
|
||||||
|
"parameters": {
|
||||||
|
"model": "Qwen/Qwen-3-8B",
|
||||||
|
"tensor_parallel_size": 1,
|
||||||
|
"load_format": "dummy",
|
||||||
|
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||||
|
"dataset_name": "sharegpt",
|
||||||
|
"num_prompts": 1000,
|
||||||
|
"max-num-seqs": 512,
|
||||||
|
"backend": "vllm",
|
||||||
|
"async-scheduling": ""
|
||||||
|
}
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -68,7 +68,7 @@ aws s3 cp s3://${S3_BUCKET}/rocm/${BUILDKITE_COMMIT}/${ROCM_VERSION_PATH}/triton
|
|||||||
aws s3 cp s3://${S3_BUCKET}/rocm/${BUILDKITE_COMMIT}/${ROCM_VERSION_PATH}/torchvision-*.whl .
|
aws s3 cp s3://${S3_BUCKET}/rocm/${BUILDKITE_COMMIT}/${ROCM_VERSION_PATH}/torchvision-*.whl .
|
||||||
aws s3 cp s3://${S3_BUCKET}/rocm/${BUILDKITE_COMMIT}/${ROCM_VERSION_PATH}/torchaudio-*.whl .
|
aws s3 cp s3://${S3_BUCKET}/rocm/${BUILDKITE_COMMIT}/${ROCM_VERSION_PATH}/torchaudio-*.whl .
|
||||||
aws s3 cp s3://${S3_BUCKET}/rocm/${BUILDKITE_COMMIT}/${ROCM_VERSION_PATH}/amdsmi-*.whl .
|
aws s3 cp s3://${S3_BUCKET}/rocm/${BUILDKITE_COMMIT}/${ROCM_VERSION_PATH}/amdsmi-*.whl .
|
||||||
aws s3 cp s3://${S3_BUCKET}/rocm/${BUILDKITE_COMMIT}/${ROCM_VERSION_PATH}/aiter-*.whl .
|
aws s3 cp s3://${S3_BUCKET}/rocm/${BUILDKITE_COMMIT}/${ROCM_VERSION_PATH}/amd_aiter-*.whl .
|
||||||
aws s3 cp s3://${S3_BUCKET}/rocm/${BUILDKITE_COMMIT}/${ROCM_VERSION_PATH}/flash-attn-*.whl .
|
aws s3 cp s3://${S3_BUCKET}/rocm/${BUILDKITE_COMMIT}/${ROCM_VERSION_PATH}/flash-attn-*.whl .
|
||||||
\`\`\`
|
\`\`\`
|
||||||
|
|
||||||
@@ -80,7 +80,7 @@ aws s3 cp s3://${S3_BUCKET}/rocm/${BUILDKITE_COMMIT}/${ROCM_VERSION_PATH}/flash-
|
|||||||
- **torchvision**: TorchVision for ROCm PyTorch
|
- **torchvision**: TorchVision for ROCm PyTorch
|
||||||
- **torchaudio**: Torchaudio for ROCm PyTorch
|
- **torchaudio**: Torchaudio for ROCm PyTorch
|
||||||
- **amdsmi**: AMD SMI Python bindings
|
- **amdsmi**: AMD SMI Python bindings
|
||||||
- **aiter**: Aiter for ROCm
|
- **amd_aiter**: Aiter for ROCm
|
||||||
- **flash-attn**: Flash Attention for ROCm
|
- **flash-attn**: Flash Attention for ROCm
|
||||||
|
|
||||||
### :warning: Notes
|
### :warning: Notes
|
||||||
|
|||||||
205
.buildkite/scripts/check-ray-compatibility.sh
Normal file
205
.buildkite/scripts/check-ray-compatibility.sh
Normal file
@@ -0,0 +1,205 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
#
|
||||||
|
# Check if Ray LLM can generate lock files that are compatible with this
|
||||||
|
# version of vllm. Downloads Ray's requirement files and runs a full
|
||||||
|
# dependency resolution with the installed vllm's constraints to see if
|
||||||
|
# a valid lock file can be produced.
|
||||||
|
#
|
||||||
|
# See: https://github.com/vllm-project/vllm/issues/33599
|
||||||
|
|
||||||
|
set -eo pipefail
|
||||||
|
|
||||||
|
RAY_BASE_URL="https://raw.githubusercontent.com/ray-project/ray/master/python"
|
||||||
|
|
||||||
|
WORK_DIR=$(mktemp -d)
|
||||||
|
trap 'rm -rf "$WORK_DIR"' EXIT
|
||||||
|
|
||||||
|
# Fetch all Ray requirement files used in the LLM depset pipeline
|
||||||
|
echo ">>> Fetching Ray requirement files"
|
||||||
|
RAY_FILES=(
|
||||||
|
"requirements.txt"
|
||||||
|
"requirements/cloud-requirements.txt"
|
||||||
|
"requirements/base-test-requirements.txt"
|
||||||
|
"requirements/llm/llm-requirements.txt"
|
||||||
|
"requirements/llm/llm-test-requirements.txt"
|
||||||
|
)
|
||||||
|
for FILE in "${RAY_FILES[@]}"; do
|
||||||
|
LOCAL_PATH="${WORK_DIR}/$(basename "$FILE")"
|
||||||
|
echo " ${FILE}"
|
||||||
|
curl -fsSL -o "$LOCAL_PATH" "${RAY_BASE_URL}/${FILE}"
|
||||||
|
done
|
||||||
|
|
||||||
|
# Extract installed vllm deps
|
||||||
|
echo ">>> Extracting installed vllm dependency constraints"
|
||||||
|
python3 - "${WORK_DIR}/vllm-constraints.txt" <<'PYEOF'
|
||||||
|
"""Write out the installed vllm's dependencies as pip constraint lines.
|
||||||
|
|
||||||
|
Ray uses vllm[audio], so audio-extra deps are included with their extra
|
||||||
|
markers stripped. The resolver cannot evaluate extra markers for a
|
||||||
|
package that is not itself being resolved from an index, so we activate
|
||||||
|
them manually here.
|
||||||
|
"""
|
||||||
|
import importlib.metadata
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
|
||||||
|
out_path = sys.argv[1]
|
||||||
|
raw_reqs = importlib.metadata.requires("vllm") or []
|
||||||
|
|
||||||
|
# Ray uses vllm[audio] – activate that extra.
|
||||||
|
ACTIVE_EXTRAS = {"audio"}
|
||||||
|
EXTRA_RE = re.compile(r"""extra\s*==\s*['"]([^'"]+)['"]""")
|
||||||
|
|
||||||
|
lines = []
|
||||||
|
for r in raw_reqs:
|
||||||
|
if ";" not in r:
|
||||||
|
# Unconditional dep — always include.
|
||||||
|
lines.append(r.strip())
|
||||||
|
continue
|
||||||
|
|
||||||
|
req_part, _, marker_part = r.partition(";")
|
||||||
|
marker_part = marker_part.strip()
|
||||||
|
|
||||||
|
extra_matches = EXTRA_RE.findall(marker_part)
|
||||||
|
if not extra_matches:
|
||||||
|
# Non-extra marker (python_version, etc.) — keep as-is.
|
||||||
|
lines.append(r.strip())
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not ACTIVE_EXTRAS.intersection(extra_matches):
|
||||||
|
continue # Skip inactive extras (tensorizer, bench, …).
|
||||||
|
|
||||||
|
# Strip the extra== conditions but keep any remaining markers
|
||||||
|
# (e.g. python_version).
|
||||||
|
cleaned = EXTRA_RE.sub("", marker_part)
|
||||||
|
cleaned = re.sub(r"\band\b\s*\band\b", "and", cleaned)
|
||||||
|
cleaned = re.sub(r"^\s*and\s+|\s+and\s*$", "", cleaned).strip()
|
||||||
|
|
||||||
|
if cleaned:
|
||||||
|
lines.append(f"{req_part.strip()} ; {cleaned}")
|
||||||
|
else:
|
||||||
|
lines.append(req_part.strip())
|
||||||
|
|
||||||
|
with open(out_path, "w") as f:
|
||||||
|
for line in lines:
|
||||||
|
f.write(line + "\n")
|
||||||
|
|
||||||
|
print(f"Wrote {len(lines)} constraints to {out_path}")
|
||||||
|
PYEOF
|
||||||
|
|
||||||
|
echo ">>> Installed vllm deps (first 20 lines):"
|
||||||
|
head -20 "${WORK_DIR}/vllm-constraints.txt"
|
||||||
|
|
||||||
|
# Remove Ray's vllm pin — the installed vllm's transitive deps
|
||||||
|
# (written above) replace it in the resolution. vllm itself cannot
|
||||||
|
# be resolved from PyPI for in-development versions, so we test
|
||||||
|
# whether Ray's requirements can coexist with vllm's dependency
|
||||||
|
# constraints instead.
|
||||||
|
sed -i '/^vllm/d' "${WORK_DIR}/llm-requirements.txt"
|
||||||
|
|
||||||
|
# Install uv if needed
|
||||||
|
if ! command -v uv &>/dev/null; then
|
||||||
|
echo ">>> Installing uv"
|
||||||
|
pip install uv -q
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Resolve: given vllm's constraints, can Ray compile a lock file?
|
||||||
|
#
|
||||||
|
# vllm's dependency constraints are the fixed side — Ray is flexible and
|
||||||
|
# can regenerate its lock files. We pass vllm's constraints via -c so
|
||||||
|
# the resolver treats them as non-negotiable bounds, then check whether
|
||||||
|
# Ray's own requirements can still be satisfied within those bounds.
|
||||||
|
echo ""
|
||||||
|
echo "============================================================"
|
||||||
|
echo ">>> Resolving: Can Ray generate compatible lock files?"
|
||||||
|
echo "============================================================"
|
||||||
|
|
||||||
|
set +e
|
||||||
|
uv pip compile \
|
||||||
|
"${WORK_DIR}/requirements.txt" \
|
||||||
|
"${WORK_DIR}/cloud-requirements.txt" \
|
||||||
|
"${WORK_DIR}/base-test-requirements.txt" \
|
||||||
|
"${WORK_DIR}/llm-requirements.txt" \
|
||||||
|
"${WORK_DIR}/llm-test-requirements.txt" \
|
||||||
|
-c "${WORK_DIR}/vllm-constraints.txt" \
|
||||||
|
--python-version 3.12 \
|
||||||
|
--python-platform x86_64-manylinux_2_31 \
|
||||||
|
--extra-index-url https://download.pytorch.org/whl/cu129 \
|
||||||
|
--index-strategy unsafe-best-match \
|
||||||
|
--unsafe-package setuptools \
|
||||||
|
--unsafe-package ray \
|
||||||
|
--no-header \
|
||||||
|
-o "${WORK_DIR}/resolved.txt" \
|
||||||
|
2>&1
|
||||||
|
EXIT_CODE=$?
|
||||||
|
set -e
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "=========================================="
|
||||||
|
if [ $EXIT_CODE -eq 0 ]; then
|
||||||
|
echo "SUCCESS: Ray can generate lock files compatible with this vllm."
|
||||||
|
echo ""
|
||||||
|
echo "Key resolved versions:"
|
||||||
|
grep -E '^(protobuf|torch|numpy|transformers)==' \
|
||||||
|
"${WORK_DIR}/resolved.txt" | sort || true
|
||||||
|
echo "=========================================="
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "FAILURE: Ray cannot generate lock files compatible with this vllm."
|
||||||
|
echo "This means a fundamental dependency conflict exists that Ray"
|
||||||
|
echo "cannot resolve by regenerating its lock files."
|
||||||
|
echo "See: https://github.com/vllm-project/vllm/issues/33599"
|
||||||
|
echo "=========================================="
|
||||||
|
|
||||||
|
# Buildkite annotation
|
||||||
|
if [ -f /usr/bin/buildkite-agent ]; then
|
||||||
|
buildkite-agent annotate --style 'warning' --context 'ray-compat' << EOF
|
||||||
|
### :warning: Ray Dependency Compatibility Warning
|
||||||
|
This PR introduces dependencies that **cannot** be resolved with Ray's requirements.
|
||||||
|
Ray would not be able to regenerate its lock files to accommodate this vllm version.
|
||||||
|
|
||||||
|
Please check the **Ray Dependency Compatibility Check** step logs for details.
|
||||||
|
See [issue #33599](https://github.com/vllm-project/vllm/issues/33599) for context.
|
||||||
|
EOF
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Notify Slack if webhook is configured.
|
||||||
|
if [ -n "$RAY_COMPAT_SLACK_WEBHOOK_URL" ]; then
|
||||||
|
echo ">>> Sending Slack notification"
|
||||||
|
# Single quotes are intentional: the f-string expressions are Python, not shell.
|
||||||
|
# shellcheck disable=SC2016
|
||||||
|
PAYLOAD=$(python3 -c '
|
||||||
|
import json, os, sys
|
||||||
|
pr = os.getenv("BUILDKITE_PULL_REQUEST", "N/A")
|
||||||
|
branch = os.getenv("BUILDKITE_BRANCH", "unknown")
|
||||||
|
url = os.getenv("BUILDKITE_BUILD_URL", "#")
|
||||||
|
data = {
|
||||||
|
"text": ":warning: Ray Dependency Compatibility Check Failed",
|
||||||
|
"blocks": [{
|
||||||
|
"type": "section",
|
||||||
|
"text": {
|
||||||
|
"type": "mrkdwn",
|
||||||
|
"text": (
|
||||||
|
"*:warning: Ray Dependency Compatibility Check Failed*\n"
|
||||||
|
f"PR #{pr} on branch `{branch}` introduces dependencies "
|
||||||
|
f"that cannot be resolved with Ray'\''s requirements.\n"
|
||||||
|
f"<{url}|View Build>"
|
||||||
|
),
|
||||||
|
},
|
||||||
|
}],
|
||||||
|
}
|
||||||
|
print(json.dumps(data))
|
||||||
|
')
|
||||||
|
|
||||||
|
HTTP_CODE=$(curl -s -o /dev/null -w "%{http_code}" -X POST "$RAY_COMPAT_SLACK_WEBHOOK_URL" \
|
||||||
|
-H 'Content-type: application/json' \
|
||||||
|
-d "$PAYLOAD")
|
||||||
|
echo " Slack webhook response: $HTTP_CODE"
|
||||||
|
else
|
||||||
|
echo ">>> Skipping Slack notification (RAY_COMPAT_SLACK_WEBHOOK_URL not set)"
|
||||||
|
fi
|
||||||
|
|
||||||
|
exit 1
|
||||||
@@ -6,6 +6,26 @@
|
|||||||
# Multi-node detection: Instead of matching on fragile group names, we detect
|
# Multi-node detection: Instead of matching on fragile group names, we detect
|
||||||
# multi-node jobs structurally by looking for the bracket command syntax
|
# multi-node jobs structurally by looking for the bracket command syntax
|
||||||
# "[node0_cmds] && [node1_cmds]" or via the NUM_NODES environment variable.
|
# "[node0_cmds] && [node1_cmds]" or via the NUM_NODES environment variable.
|
||||||
|
#
|
||||||
|
###############################################################################
|
||||||
|
# QUOTING / COMMAND PASSING
|
||||||
|
#
|
||||||
|
# Passing commands as positional arguments ($*) is fragile when the command
|
||||||
|
# string itself contains double quotes, e.g.:
|
||||||
|
#
|
||||||
|
# bash run-amd-test.sh "export FLAGS="value" && pytest -m "not slow""
|
||||||
|
#
|
||||||
|
# The outer shell resolves the nested quotes *before* this script runs, so
|
||||||
|
# the script receives mangled input it cannot fully recover.
|
||||||
|
#
|
||||||
|
# Preferred: pass commands via the VLLM_TEST_COMMANDS environment variable:
|
||||||
|
#
|
||||||
|
# export VLLM_TEST_COMMANDS='export FLAGS="value" && pytest -m "not slow"'
|
||||||
|
# bash run-amd-test.sh
|
||||||
|
#
|
||||||
|
# Single-quoted assignment preserves all inner double quotes verbatim.
|
||||||
|
# The $* path is kept for backward compatibility but callers should migrate.
|
||||||
|
###############################################################################
|
||||||
set -o pipefail
|
set -o pipefail
|
||||||
|
|
||||||
# Export Python path
|
# Export Python path
|
||||||
@@ -79,26 +99,157 @@ is_multi_node() {
|
|||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
handle_pytest_exit() {
|
||||||
|
local exit_code=$1
|
||||||
|
if [ "$exit_code" -eq 5 ]; then
|
||||||
|
echo "Pytest exit code 5 (no tests collected) - treating as success."
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
exit "$exit_code"
|
||||||
|
}
|
||||||
|
|
||||||
###############################################################################
|
###############################################################################
|
||||||
# Pytest marker re-quoting
|
# Pytest marker/keyword re-quoting
|
||||||
#
|
#
|
||||||
# When commands are passed through Buildkite -> shell -> $* -> bash -c,
|
# When commands are passed through Buildkite -> shell -> $* -> bash -c,
|
||||||
# quotes around pytest -m marker expressions get stripped:
|
# quotes around multi-word pytest -m/-k expressions get stripped:
|
||||||
# pytest -v -s -m 'not cpu_test' v1/core
|
# pytest -v -s -m 'not cpu_test' v1/core
|
||||||
# becomes:
|
# becomes:
|
||||||
# pytest -v -s -m not cpu_test v1/core
|
# pytest -v -s -m not cpu_test v1/core
|
||||||
#
|
#
|
||||||
# pytest then interprets "cpu_test" as a file path, not part of the marker.
|
# pytest then interprets "cpu_test" as a file path, not part of the marker.
|
||||||
# This function detects unquoted multi-word marker expressions and re-quotes
|
#
|
||||||
# them so they survive the final bash -c expansion.
|
# This function detects unquoted expressions after -m/-k and re-quotes them
|
||||||
|
# by collecting tokens until a recognizable boundary is reached:
|
||||||
|
# - test path (contains '/')
|
||||||
|
# - test file (ends with '.py')
|
||||||
|
# - another pytest flag (--xxx or -x single-char flags)
|
||||||
|
# - command separator (&& || ; |)
|
||||||
|
# - environment variable assignment (FOO=bar)
|
||||||
|
#
|
||||||
|
# Single-word markers (e.g. -m cpu_test, -m hybrid_model) pass through
|
||||||
|
# unquoted since they have no spaces and work fine.
|
||||||
|
#
|
||||||
|
# Already-quoted expressions (containing literal single quotes) are passed
|
||||||
|
# through untouched to avoid double-quoting values injected by
|
||||||
|
# apply_rocm_test_overrides.
|
||||||
|
#
|
||||||
|
# NOTE: This ONLY fixes -m/-k flags. It cannot recover arbitrary inner
|
||||||
|
# double-quotes stripped by the calling shell (see header comment).
|
||||||
|
# Use VLLM_TEST_COMMANDS to avoid the problem entirely.
|
||||||
###############################################################################
|
###############################################################################
|
||||||
|
|
||||||
re_quote_pytest_markers() {
|
re_quote_pytest_markers() {
|
||||||
local cmds="$1"
|
local input="$1"
|
||||||
# Pattern: -m not <identifier> -> -m 'not <identifier>'
|
local output=""
|
||||||
# Handles the common cases: 'not cpu_test', 'not slow_test', etc.
|
local collecting=false
|
||||||
cmds=$(echo "$cmds" | sed -E "s/-m not ([a-zA-Z_][a-zA-Z0-9_]*)/-m 'not \1'/g")
|
local marker_buf=""
|
||||||
echo "$cmds"
|
|
||||||
|
# Strip backslash-newline continuations, then flatten remaining newlines
|
||||||
|
local flat="${input//$'\\\n'/ }"
|
||||||
|
flat="${flat//$'\n'/ }"
|
||||||
|
|
||||||
|
# Disable globbing to prevent *.py etc. from expanding during read -ra
|
||||||
|
local restore_glob
|
||||||
|
restore_glob="$(shopt -p -o noglob 2>/dev/null || true)"
|
||||||
|
set -o noglob
|
||||||
|
local -a words
|
||||||
|
read -ra words <<< "$flat"
|
||||||
|
eval "$restore_glob"
|
||||||
|
|
||||||
|
for word in "${words[@]}"; do
|
||||||
|
if $collecting; then
|
||||||
|
# If the token we're about to collect already contains a literal
|
||||||
|
# single quote, the expression was already quoted upstream.
|
||||||
|
# Flush and stop collecting.
|
||||||
|
if [[ "$word" == *"'"* ]]; then
|
||||||
|
if [[ -n "$marker_buf" ]]; then
|
||||||
|
# Should not normally happen (partial buf + quote), flush raw
|
||||||
|
output+="${marker_buf} "
|
||||||
|
marker_buf=""
|
||||||
|
fi
|
||||||
|
output+="${word} "
|
||||||
|
collecting=false
|
||||||
|
continue
|
||||||
|
fi
|
||||||
|
|
||||||
|
local is_boundary=false
|
||||||
|
case "$word" in
|
||||||
|
# Line-continuation artifact
|
||||||
|
"\\")
|
||||||
|
is_boundary=true ;;
|
||||||
|
# Command separators
|
||||||
|
"&&"|"||"|";"|"|")
|
||||||
|
is_boundary=true ;;
|
||||||
|
# Long flags (--ignore, --shard-id, etc.)
|
||||||
|
--*)
|
||||||
|
is_boundary=true ;;
|
||||||
|
# Short flags (-v, -s, -x, etc.) but NOT negative marker tokens
|
||||||
|
# like "not" which don't start with "-". Also skip -k/-m which
|
||||||
|
# would start a new marker (handled below).
|
||||||
|
-[a-zA-Z])
|
||||||
|
is_boundary=true ;;
|
||||||
|
# Test path (contains /)
|
||||||
|
*/*)
|
||||||
|
is_boundary=true ;;
|
||||||
|
# Test file (ends with .py, possibly with ::method)
|
||||||
|
*.py|*.py::*)
|
||||||
|
is_boundary=true ;;
|
||||||
|
# Environment variable assignment preceding a command (FOO=bar)
|
||||||
|
*=*)
|
||||||
|
# Only treat as boundary if it looks like VAR=value, not
|
||||||
|
# pytest filter expressions like num_gpus=2 inside markers
|
||||||
|
if [[ "$word" =~ ^[A-Z_][A-Z0-9_]*= ]]; then
|
||||||
|
is_boundary=true
|
||||||
|
fi
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
|
||||||
|
if $is_boundary; then
|
||||||
|
# Flush the collected marker expression
|
||||||
|
if [[ "$marker_buf" == *" "* || "$marker_buf" == *"("* ]]; then
|
||||||
|
output+="'${marker_buf}' "
|
||||||
|
else
|
||||||
|
output+="${marker_buf} "
|
||||||
|
fi
|
||||||
|
collecting=false
|
||||||
|
marker_buf=""
|
||||||
|
# Check if this boundary word itself starts a new -m/-k
|
||||||
|
if [[ "$word" == "-m" || "$word" == "-k" ]]; then
|
||||||
|
output+="${word} "
|
||||||
|
collecting=true
|
||||||
|
# Drop stray backslash tokens silently
|
||||||
|
elif [[ "$word" == "\\" ]]; then
|
||||||
|
:
|
||||||
|
else
|
||||||
|
output+="${word} "
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
# Accumulate into marker buffer
|
||||||
|
if [[ -n "$marker_buf" ]]; then
|
||||||
|
marker_buf+=" ${word}"
|
||||||
|
else
|
||||||
|
marker_buf="${word}"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
elif [[ "$word" == "-m" || "$word" == "-k" ]]; then
|
||||||
|
output+="${word} "
|
||||||
|
collecting=true
|
||||||
|
marker_buf=""
|
||||||
|
else
|
||||||
|
output+="${word} "
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
# Flush any trailing marker expression (marker at end of command)
|
||||||
|
if $collecting && [[ -n "$marker_buf" ]]; then
|
||||||
|
if [[ "$marker_buf" == *" "* || "$marker_buf" == *"("* ]]; then
|
||||||
|
output+="'${marker_buf}'"
|
||||||
|
else
|
||||||
|
output+="${marker_buf}"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "${output% }"
|
||||||
}
|
}
|
||||||
|
|
||||||
###############################################################################
|
###############################################################################
|
||||||
@@ -231,11 +382,35 @@ HF_CACHE="$(realpath ~)/huggingface"
|
|||||||
mkdir -p "${HF_CACHE}"
|
mkdir -p "${HF_CACHE}"
|
||||||
HF_MOUNT="/root/.cache/huggingface"
|
HF_MOUNT="/root/.cache/huggingface"
|
||||||
|
|
||||||
commands="$*"
|
# ---- Command source selection ----
|
||||||
|
# Prefer VLLM_TEST_COMMANDS (preserves all inner quoting intact).
|
||||||
|
# Fall back to $* for backward compatibility, but warn that inner
|
||||||
|
# double-quotes will have been stripped by the calling shell.
|
||||||
|
if [[ -n "${VLLM_TEST_COMMANDS:-}" ]]; then
|
||||||
|
commands="${VLLM_TEST_COMMANDS}"
|
||||||
|
echo "Commands sourced from VLLM_TEST_COMMANDS (quoting preserved)"
|
||||||
|
else
|
||||||
|
commands="$*"
|
||||||
|
if [[ -z "$commands" ]]; then
|
||||||
|
echo "Error: No test commands provided." >&2
|
||||||
|
echo "Usage:" >&2
|
||||||
|
echo " Preferred: VLLM_TEST_COMMANDS='...' bash $0" >&2
|
||||||
|
echo " Legacy: bash $0 \"commands here\"" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
echo "Commands sourced from positional args (legacy mode)"
|
||||||
|
echo "WARNING: Inner double-quotes in the command string may have been"
|
||||||
|
echo " stripped by the calling shell. If you see syntax errors, switch to:"
|
||||||
|
echo " export VLLM_TEST_COMMANDS='your commands here'"
|
||||||
|
echo " bash $0"
|
||||||
|
fi
|
||||||
|
|
||||||
echo "Raw commands: $commands"
|
echo "Raw commands: $commands"
|
||||||
|
|
||||||
# Fix quoting before ROCm overrides (so overrides see correct structure)
|
# Fix quoting before ROCm overrides (so overrides see correct structure)
|
||||||
commands=$(re_quote_pytest_markers "$commands")
|
commands=$(re_quote_pytest_markers "$commands")
|
||||||
|
echo "After re-quoting: $commands"
|
||||||
|
|
||||||
commands=$(apply_rocm_test_overrides "$commands")
|
commands=$(apply_rocm_test_overrides "$commands")
|
||||||
echo "Final commands: $commands"
|
echo "Final commands: $commands"
|
||||||
|
|
||||||
@@ -248,6 +423,18 @@ if [[ -z "$render_gid" ]]; then
|
|||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
# --- RDMA device passthrough (conditional) ---
|
||||||
|
# If the host has RDMA devices, pass them through so tests like
|
||||||
|
# test_moriio_connector can access ibverbs. On hosts without RDMA
|
||||||
|
# hardware the tests will gracefully skip via _rdma_available().
|
||||||
|
RDMA_FLAGS=""
|
||||||
|
if [ -d /dev/infiniband ]; then
|
||||||
|
echo "RDMA devices detected on host, enabling passthrough"
|
||||||
|
RDMA_FLAGS="--device /dev/infiniband --cap-add=IPC_LOCK"
|
||||||
|
else
|
||||||
|
echo "No RDMA devices found on host, RDMA tests will be skipped"
|
||||||
|
fi
|
||||||
|
|
||||||
# --- Route: multi-node vs single-node ---
|
# --- Route: multi-node vs single-node ---
|
||||||
if is_multi_node "$commands"; then
|
if is_multi_node "$commands"; then
|
||||||
echo "--- Multi-node job detected"
|
echo "--- Multi-node job detected"
|
||||||
@@ -282,7 +469,9 @@ if is_multi_node "$commands"; then
|
|||||||
done
|
done
|
||||||
|
|
||||||
/bin/bash -c "${composite_command}"
|
/bin/bash -c "${composite_command}"
|
||||||
|
exit_code=$?
|
||||||
cleanup_network
|
cleanup_network
|
||||||
|
handle_pytest_exit "$exit_code"
|
||||||
else
|
else
|
||||||
echo "Multi-node job detected but failed to parse bracket command syntax."
|
echo "Multi-node job detected but failed to parse bracket command syntax."
|
||||||
echo "Expected format: prefix ; [node0_cmd1, node0_cmd2] && [node1_cmd1, node1_cmd2]"
|
echo "Expected format: prefix ; [node0_cmd1, node0_cmd2] && [node1_cmd1, node1_cmd2]"
|
||||||
@@ -295,6 +484,7 @@ else
|
|||||||
echo "Render devices: $BUILDKITE_AGENT_META_DATA_RENDER_DEVICES"
|
echo "Render devices: $BUILDKITE_AGENT_META_DATA_RENDER_DEVICES"
|
||||||
docker run \
|
docker run \
|
||||||
--device /dev/kfd $BUILDKITE_AGENT_META_DATA_RENDER_DEVICES \
|
--device /dev/kfd $BUILDKITE_AGENT_META_DATA_RENDER_DEVICES \
|
||||||
|
$RDMA_FLAGS \
|
||||||
--network=host \
|
--network=host \
|
||||||
--shm-size=16gb \
|
--shm-size=16gb \
|
||||||
--group-add "$render_gid" \
|
--group-add "$render_gid" \
|
||||||
@@ -308,4 +498,7 @@ else
|
|||||||
--name "${container_name}" \
|
--name "${container_name}" \
|
||||||
"${image_name}" \
|
"${image_name}" \
|
||||||
/bin/bash -c "${commands}"
|
/bin/bash -c "${commands}"
|
||||||
|
|
||||||
|
exit_code=$?
|
||||||
|
handle_pytest_exit "$exit_code"
|
||||||
fi
|
fi
|
||||||
|
|||||||
@@ -1,26 +1,43 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
set -euox pipefail
|
set -euox pipefail
|
||||||
|
export VLLM_CPU_CI_ENV=0
|
||||||
|
|
||||||
echo "--- PP+TP"
|
echo "--- PP+TP"
|
||||||
vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -pp=2 &
|
vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -pp=2 &
|
||||||
server_pid=$!
|
server_pid=$!
|
||||||
timeout 600 bash -c "until curl localhost:8000/v1/models; do sleep 1; done" || exit 1
|
timeout 600 bash -c "until curl localhost:8000/v1/models > /dev/null 2>&1; do sleep 1; done" || exit 1
|
||||||
vllm bench serve \
|
vllm bench serve \
|
||||||
--backend vllm \
|
--backend vllm \
|
||||||
--dataset-name random \
|
--dataset-name random \
|
||||||
--model meta-llama/Llama-3.2-3B-Instruct \
|
--model meta-llama/Llama-3.2-3B-Instruct \
|
||||||
--num-prompts 20 \
|
--num-prompts 20 \
|
||||||
|
--result-dir ./test_results \
|
||||||
|
--result-filename tp_pp.json \
|
||||||
|
--save-result \
|
||||||
--endpoint /v1/completions
|
--endpoint /v1/completions
|
||||||
kill -s SIGTERM $server_pid &
|
kill -s SIGTERM $server_pid; wait $server_pid || true
|
||||||
|
failed_req=$(jq '.failed' ./test_results/tp_pp.json)
|
||||||
|
if [ "$failed_req" -ne 0 ]; then
|
||||||
|
echo "Some requests were failed!"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
echo "--- DP+TP"
|
echo "--- DP+TP"
|
||||||
vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -dp=2 &
|
vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -dp=2 &
|
||||||
server_pid=$!
|
server_pid=$!
|
||||||
timeout 600 bash -c "until curl localhost:8000/v1/models; do sleep 1; done" || exit 1
|
timeout 600 bash -c "until curl localhost:8000/v1/models > /dev/null 2>&1; do sleep 1; done" || exit 1
|
||||||
vllm bench serve \
|
vllm bench serve \
|
||||||
--backend vllm \
|
--backend vllm \
|
||||||
--dataset-name random \
|
--dataset-name random \
|
||||||
--model meta-llama/Llama-3.2-3B-Instruct \
|
--model meta-llama/Llama-3.2-3B-Instruct \
|
||||||
--num-prompts 20 \
|
--num-prompts 20 \
|
||||||
|
--result-dir ./test_results \
|
||||||
|
--result-filename dp_pp.json \
|
||||||
|
--save-result \
|
||||||
--endpoint /v1/completions
|
--endpoint /v1/completions
|
||||||
kill -s SIGTERM $server_pid &
|
kill -s SIGTERM $server_pid; wait $server_pid || true
|
||||||
|
failed_req=$(jq '.failed' ./test_results/dp_pp.json)
|
||||||
|
if [ "$failed_req" -ne 0 ]; then
|
||||||
|
echo "Some requests were failed!"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|||||||
@@ -1,9 +1,27 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
# This script build the CPU docker image and run the offline inference inside the container.
|
# This script builds the HPU docker image and runs the offline inference inside the container.
|
||||||
# It serves a sanity check for compilation and basic model usage.
|
# It serves a sanity check for compilation and basic model usage.
|
||||||
|
#
|
||||||
|
# vllm-gaudi compatibility pinning:
|
||||||
|
# The vllm-gaudi plugin is installed on top of the vllm upstream checkout used by this CI job.
|
||||||
|
# When upstream vllm changes its API, the plugin may break before it has been updated.
|
||||||
|
# To handle this, the vllm-gaudi repository maintains a file:
|
||||||
|
# vllm/last-good-commit-for-vllm-gaudi/VLLM_COMMUNITY_COMMIT
|
||||||
|
# The first line of that file controls what version of vllm is used inside the Docker image:
|
||||||
|
# - "latest" : no checkout override; the current Buildkite CI commit is used as-is.
|
||||||
|
# - "<commit SHA>" : vllm is checked out to that specific commit before building, pinning
|
||||||
|
# the test to a known-compatible baseline.
|
||||||
|
# To unpin (resume testing against the live vllm tip), set the file content back to "latest".
|
||||||
set -exuo pipefail
|
set -exuo pipefail
|
||||||
|
|
||||||
|
# Fetch the vllm community commit reference from vllm-gaudi (first line only).
|
||||||
|
VLLM_COMMUNITY_COMMIT=$(curl -s \
|
||||||
|
https://raw.githubusercontent.com/vllm-project/vllm-gaudi/vllm/last-good-commit-for-vllm-gaudi/VLLM_COMMUNITY_COMMIT \
|
||||||
|
| head -1 | tr -d '\n')
|
||||||
|
|
||||||
|
echo "Using vllm community commit: ${VLLM_COMMUNITY_COMMIT}"
|
||||||
|
|
||||||
# Try building the docker image
|
# Try building the docker image
|
||||||
image_name="hpu/upstream-vllm-ci:${BUILDKITE_COMMIT}"
|
image_name="hpu/upstream-vllm-ci:${BUILDKITE_COMMIT}"
|
||||||
container_name="hpu-upstream-vllm-ci-${BUILDKITE_COMMIT}-container"
|
container_name="hpu-upstream-vllm-ci-${BUILDKITE_COMMIT}-container"
|
||||||
@@ -12,6 +30,13 @@ FROM gaudi-base-image:latest
|
|||||||
|
|
||||||
COPY ./ /workspace/vllm
|
COPY ./ /workspace/vllm
|
||||||
|
|
||||||
|
# If VLLM_COMMUNITY_COMMIT is a specific commit (not "latest"), check it out to pin vllm
|
||||||
|
# to the version known to be compatible with vllm-gaudi. When the value is "latest",
|
||||||
|
# the current checkout (the Buildkite CI commit) is used unchanged.
|
||||||
|
RUN if [ "${VLLM_COMMUNITY_COMMIT}" != "latest" ]; then \
|
||||||
|
cd /workspace/vllm && git fetch --unshallow 2>/dev/null || true && git checkout ${VLLM_COMMUNITY_COMMIT}; \
|
||||||
|
fi
|
||||||
|
|
||||||
WORKDIR /workspace/vllm
|
WORKDIR /workspace/vllm
|
||||||
|
|
||||||
ENV no_proxy=localhost,127.0.0.1
|
ENV no_proxy=localhost,127.0.0.1
|
||||||
|
|||||||
@@ -156,8 +156,9 @@ steps:
|
|||||||
|
|
||||||
- label: Entrypoints Integration Test (API Server 1) # 100min
|
- label: Entrypoints Integration Test (API Server 1) # 100min
|
||||||
timeout_in_minutes: 130
|
timeout_in_minutes: 130
|
||||||
mirror_hardwares: [amdexperimental]
|
mirror_hardwares: [amdexperimental, amdproduction]
|
||||||
agent_pool: mi325_1
|
agent_pool: mi325_1
|
||||||
|
optional: true
|
||||||
# grade: Blocking
|
# grade: Blocking
|
||||||
working_dir: "/vllm-workspace/tests"
|
working_dir: "/vllm-workspace/tests"
|
||||||
fast_check: true
|
fast_check: true
|
||||||
@@ -173,8 +174,9 @@ steps:
|
|||||||
|
|
||||||
- label: Entrypoints Integration Test (API Server 2)
|
- label: Entrypoints Integration Test (API Server 2)
|
||||||
timeout_in_minutes: 50
|
timeout_in_minutes: 50
|
||||||
mirror_hardwares: [amdexperimental]
|
mirror_hardwares: [amdexperimental, amdproduction]
|
||||||
agent_pool: mi325_1
|
agent_pool: mi325_1
|
||||||
|
optional: true
|
||||||
# grade: Blocking
|
# grade: Blocking
|
||||||
working_dir: "/vllm-workspace/tests"
|
working_dir: "/vllm-workspace/tests"
|
||||||
fast_check: true
|
fast_check: true
|
||||||
@@ -192,8 +194,9 @@ steps:
|
|||||||
|
|
||||||
- label: Entrypoints Integration Test (Pooling)
|
- label: Entrypoints Integration Test (Pooling)
|
||||||
timeout_in_minutes: 50
|
timeout_in_minutes: 50
|
||||||
mirror_hardwares: [amdexperimental]
|
mirror_hardwares: [amdexperimental, amdproduction]
|
||||||
agent_pool: mi325_1
|
agent_pool: mi325_1
|
||||||
|
optional: true
|
||||||
# grade: Blocking
|
# grade: Blocking
|
||||||
working_dir: "/vllm-workspace/tests"
|
working_dir: "/vllm-workspace/tests"
|
||||||
fast_check: true
|
fast_check: true
|
||||||
@@ -207,8 +210,9 @@ steps:
|
|||||||
|
|
||||||
- label: Entrypoints Integration Test (Responses API)
|
- label: Entrypoints Integration Test (Responses API)
|
||||||
timeout_in_minutes: 50
|
timeout_in_minutes: 50
|
||||||
mirror_hardwares: [amdexperimental]
|
mirror_hardwares: [amdexperimental, amdproduction]
|
||||||
agent_pool: mi325_1
|
agent_pool: mi325_1
|
||||||
|
optional: true
|
||||||
# grade: Blocking
|
# grade: Blocking
|
||||||
working_dir: "/vllm-workspace/tests"
|
working_dir: "/vllm-workspace/tests"
|
||||||
fast_check: true
|
fast_check: true
|
||||||
@@ -222,8 +226,9 @@ steps:
|
|||||||
|
|
||||||
- label: Distributed Tests (4 GPUs) # 35min
|
- label: Distributed Tests (4 GPUs) # 35min
|
||||||
timeout_in_minutes: 50
|
timeout_in_minutes: 50
|
||||||
mirror_hardwares: [amdexperimental]
|
mirror_hardwares: [amdexperimental, amdproduction]
|
||||||
agent_pool: mi325_4
|
agent_pool: mi325_4
|
||||||
|
optional: true
|
||||||
# grade: Blocking
|
# grade: Blocking
|
||||||
working_dir: "/vllm-workspace/tests"
|
working_dir: "/vllm-workspace/tests"
|
||||||
num_gpus: 4
|
num_gpus: 4
|
||||||
@@ -278,14 +283,16 @@ steps:
|
|||||||
- popd
|
- popd
|
||||||
# NEW rlhf examples
|
# NEW rlhf examples
|
||||||
- pushd ../examples/offline_inference/new_weight_syncing
|
- pushd ../examples/offline_inference/new_weight_syncing
|
||||||
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py
|
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_nccl.py
|
||||||
|
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_ipc.py
|
||||||
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_async_new_apis.py
|
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_async_new_apis.py
|
||||||
- popd
|
- popd
|
||||||
|
|
||||||
- label: Distributed Tests (8 GPUs) # 4min
|
- label: Distributed Tests (8 GPUs) # 4min
|
||||||
timeout_in_minutes: 10
|
timeout_in_minutes: 10
|
||||||
mirror_hardwares: [amdexperimental]
|
mirror_hardwares: [amdexperimental, amdproduction]
|
||||||
agent_pool: mi325_8
|
agent_pool: mi325_8
|
||||||
|
optional: true
|
||||||
# grade: Blocking
|
# grade: Blocking
|
||||||
gpu: h100
|
gpu: h100
|
||||||
num_gpus: 8
|
num_gpus: 8
|
||||||
@@ -380,10 +387,9 @@ steps:
|
|||||||
|
|
||||||
- label: V1 Test e2e + engine # 65min
|
- label: V1 Test e2e + engine # 65min
|
||||||
timeout_in_minutes: 90
|
timeout_in_minutes: 90
|
||||||
mirror_hardwares: [amdexperimental]
|
mirror_hardwares: [amdexperimental, amdproduction]
|
||||||
# The test uses 4 GPUs, but we schedule it on 8-GPU machines for stability.
|
agent_pool: mi325_1
|
||||||
# See discussion here: https://github.com/vllm-project/vllm/pull/31040
|
optional: true
|
||||||
agent_pool: mi325_8
|
|
||||||
# grade: Blocking
|
# grade: Blocking
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/
|
- vllm/
|
||||||
@@ -394,6 +400,34 @@ steps:
|
|||||||
- pytest -v -s v1/e2e
|
- pytest -v -s v1/e2e
|
||||||
- pytest -v -s v1/engine
|
- pytest -v -s v1/engine
|
||||||
|
|
||||||
|
- label: V1 Test e2e (2 GPUs) # 65min
|
||||||
|
timeout_in_minutes: 90
|
||||||
|
mirror_hardwares: [amdexperimental, amdproduction]
|
||||||
|
agent_pool: mi325_2
|
||||||
|
optional: true
|
||||||
|
# grade: Blocking
|
||||||
|
source_file_dependencies:
|
||||||
|
- vllm/
|
||||||
|
- tests/v1
|
||||||
|
commands:
|
||||||
|
# Only run tests that need exactly 2 GPUs
|
||||||
|
- pytest -v -s v1/e2e/test_spec_decode.py -k "tensor_parallelism"
|
||||||
|
|
||||||
|
- label: V1 Test e2e (4 GPUs) # 65min
|
||||||
|
timeout_in_minutes: 90
|
||||||
|
mirror_hardwares: [amdexperimental, amdproduction]
|
||||||
|
# The test uses 4 GPUs, but we schedule it on 8-GPU machines for stability.
|
||||||
|
# See discussion here: https://github.com/vllm-project/vllm/pull/31040
|
||||||
|
agent_pool: mi325_4
|
||||||
|
optional: true
|
||||||
|
# grade: Blocking
|
||||||
|
source_file_dependencies:
|
||||||
|
- vllm/
|
||||||
|
- tests/v1
|
||||||
|
commands:
|
||||||
|
# Only run tests that need 4 GPUs
|
||||||
|
- pytest -v -s v1/e2e/test_spec_decode.py -k "eagle_correctness_heavy"
|
||||||
|
|
||||||
- label: V1 Test entrypoints # 35min
|
- label: V1 Test entrypoints # 35min
|
||||||
timeout_in_minutes: 50
|
timeout_in_minutes: 50
|
||||||
mirror_hardwares: [amdexperimental, amdproduction, amdtentative]
|
mirror_hardwares: [amdexperimental, amdproduction, amdtentative]
|
||||||
@@ -407,8 +441,9 @@ steps:
|
|||||||
|
|
||||||
- label: V1 Test others # 42min
|
- label: V1 Test others # 42min
|
||||||
timeout_in_minutes: 60
|
timeout_in_minutes: 60
|
||||||
mirror_hardwares: [amdexperimental]
|
mirror_hardwares: [amdexperimental, amdproduction]
|
||||||
agent_pool: mi325_1
|
agent_pool: mi325_1
|
||||||
|
optional: true
|
||||||
# grade: Blocking
|
# grade: Blocking
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/
|
- vllm/
|
||||||
@@ -435,8 +470,9 @@ steps:
|
|||||||
# TODO: Add the "V1 Test attetion (MI300)" test group
|
# TODO: Add the "V1 Test attetion (MI300)" test group
|
||||||
|
|
||||||
- label: V1 Test attention (H100) # 10min
|
- label: V1 Test attention (H100) # 10min
|
||||||
mirror_hardwares: [amdexperimental]
|
mirror_hardwares: [amdexperimental, amdproduction]
|
||||||
agent_pool: mi325_1
|
agent_pool: mi325_1
|
||||||
|
optional: true
|
||||||
# grade: Blocking
|
# grade: Blocking
|
||||||
timeout_in_minutes: 30
|
timeout_in_minutes: 30
|
||||||
gpu: h100
|
gpu: h100
|
||||||
@@ -540,8 +576,9 @@ steps:
|
|||||||
|
|
||||||
- label: Samplers Test # 56min
|
- label: Samplers Test # 56min
|
||||||
timeout_in_minutes: 75
|
timeout_in_minutes: 75
|
||||||
mirror_hardwares: [amdexperimental]
|
mirror_hardwares: [amdexperimental, amdproduction]
|
||||||
agent_pool: mi325_1
|
agent_pool: mi325_1
|
||||||
|
optional: true
|
||||||
# grade: Blocking
|
# grade: Blocking
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/model_executor/layers
|
- vllm/model_executor/layers
|
||||||
@@ -553,8 +590,9 @@ steps:
|
|||||||
|
|
||||||
- label: LoRA Test %N # 20min each
|
- label: LoRA Test %N # 20min each
|
||||||
timeout_in_minutes: 30
|
timeout_in_minutes: 30
|
||||||
mirror_hardwares: [amdexperimental]
|
mirror_hardwares: [amdexperimental, amdproduction]
|
||||||
agent_pool: mi325_1
|
agent_pool: mi325_1
|
||||||
|
optional: true
|
||||||
# grade: Blocking
|
# grade: Blocking
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/lora
|
- vllm/lora
|
||||||
@@ -572,6 +610,8 @@ steps:
|
|||||||
--ignore=lora/test_qwen3moe_tp.py
|
--ignore=lora/test_qwen3moe_tp.py
|
||||||
parallelism: 4
|
parallelism: 4
|
||||||
|
|
||||||
|
##### .buildkite/test_areas/pytorch.yaml #####
|
||||||
|
# corresponds to .buildkite/test_areas/pytorch.yaml
|
||||||
- label: PyTorch Compilation Unit Tests # 15min
|
- label: PyTorch Compilation Unit Tests # 15min
|
||||||
timeout_in_minutes: 30
|
timeout_in_minutes: 30
|
||||||
mirror_hardwares: [amdexperimental, amdproduction]
|
mirror_hardwares: [amdexperimental, amdproduction]
|
||||||
@@ -589,6 +629,20 @@ steps:
|
|||||||
# they do not suffer from https://github.com/vllm-project/vllm/issues/28965
|
# they do not suffer from https://github.com/vllm-project/vllm/issues/28965
|
||||||
- "find compile/ -maxdepth 1 -name 'test_*.py' -exec pytest -s -v {} \\\\;"
|
- "find compile/ -maxdepth 1 -name 'test_*.py' -exec pytest -s -v {} \\\\;"
|
||||||
|
|
||||||
|
# corresponds to .buildkite/test_areas/pytorch.yaml
|
||||||
|
- label: PyTorch Compilation Passes Unit Tests
|
||||||
|
timeout_in_minutes: 20
|
||||||
|
mirror_hardwares: [amdexperimental, amdproduction]
|
||||||
|
agent_pool: mi325_1
|
||||||
|
source_file_dependencies:
|
||||||
|
- vllm/
|
||||||
|
- tests/compile/passes
|
||||||
|
commands:
|
||||||
|
# TODO: clean up this comment if not needed. It is used to
|
||||||
|
# keep track of the tests changes during vLLM IR Ops refactoring.
|
||||||
|
# Use `find` to launch multiple instances of pytest.
|
||||||
|
- "find compile/passes -maxdepth 1 -name 'test_*.py' -exec pytest -s -v {} \\\\;"
|
||||||
|
|
||||||
- label: PyTorch Fullgraph Smoke Test # 15min
|
- label: PyTorch Fullgraph Smoke Test # 15min
|
||||||
timeout_in_minutes: 30
|
timeout_in_minutes: 30
|
||||||
mirror_hardwares: [amdexperimental, amdproduction]
|
mirror_hardwares: [amdexperimental, amdproduction]
|
||||||
@@ -664,8 +718,9 @@ steps:
|
|||||||
|
|
||||||
- label: Kernels Quantization Test %N # 64min
|
- label: Kernels Quantization Test %N # 64min
|
||||||
timeout_in_minutes: 90
|
timeout_in_minutes: 90
|
||||||
mirror_hardwares: [amdexperimental]
|
mirror_hardwares: [amdexperimental, amdproduction]
|
||||||
agent_pool: mi325_1
|
agent_pool: mi325_1
|
||||||
|
optional: true
|
||||||
# grade: Blocking
|
# grade: Blocking
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- csrc/quantization/
|
- csrc/quantization/
|
||||||
@@ -798,8 +853,9 @@ steps:
|
|||||||
|
|
||||||
- label: LM Eval Small Models # 53min
|
- label: LM Eval Small Models # 53min
|
||||||
timeout_in_minutes: 75
|
timeout_in_minutes: 75
|
||||||
mirror_hardwares: [amdexperimental]
|
mirror_hardwares: [amdexperimental, amdproduction]
|
||||||
agent_pool: mi325_1
|
agent_pool: mi325_1
|
||||||
|
optional: true
|
||||||
# grade: Blocking
|
# grade: Blocking
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- csrc/
|
- csrc/
|
||||||
@@ -860,8 +916,9 @@ steps:
|
|||||||
|
|
||||||
- label: Basic Models Tests (Other)
|
- label: Basic Models Tests (Other)
|
||||||
timeout_in_minutes: 45
|
timeout_in_minutes: 45
|
||||||
mirror_hardwares: [amdexperimental]
|
mirror_hardwares: [amdexperimental, amdproduction]
|
||||||
agent_pool: mi325_1
|
agent_pool: mi325_1
|
||||||
|
optional: true
|
||||||
# grade: Blocking
|
# grade: Blocking
|
||||||
torch_nightly: true
|
torch_nightly: true
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
@@ -902,8 +959,9 @@ steps:
|
|||||||
|
|
||||||
- label: Language Models Tests (Extra Standard) %N
|
- label: Language Models Tests (Extra Standard) %N
|
||||||
timeout_in_minutes: 45
|
timeout_in_minutes: 45
|
||||||
mirror_hardwares: [amdexperimental]
|
mirror_hardwares: [amdexperimental, amdproduction]
|
||||||
agent_pool: mi325_1
|
agent_pool: mi325_1
|
||||||
|
optional: true
|
||||||
# grade: Blocking
|
# grade: Blocking
|
||||||
torch_nightly: true
|
torch_nightly: true
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
@@ -923,8 +981,9 @@ steps:
|
|||||||
|
|
||||||
- label: Language Models Tests (Hybrid) %N
|
- label: Language Models Tests (Hybrid) %N
|
||||||
timeout_in_minutes: 75
|
timeout_in_minutes: 75
|
||||||
mirror_hardwares: [amdexperimental]
|
mirror_hardwares: [amdexperimental, amdproduction]
|
||||||
agent_pool: mi325_1
|
agent_pool: mi325_1
|
||||||
|
optional: true
|
||||||
# grade: Blocking
|
# grade: Blocking
|
||||||
torch_nightly: true
|
torch_nightly: true
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
@@ -944,7 +1003,7 @@ steps:
|
|||||||
|
|
||||||
- label: Language Models Test (Extended Generation) # 80min
|
- label: Language Models Test (Extended Generation) # 80min
|
||||||
timeout_in_minutes: 110
|
timeout_in_minutes: 110
|
||||||
mirror_hardwares: [amdexperimental]
|
mirror_hardwares: [amdexperimental, amdproduction]
|
||||||
agent_pool: mi325_1
|
agent_pool: mi325_1
|
||||||
# grade: Blocking
|
# grade: Blocking
|
||||||
optional: true
|
optional: true
|
||||||
@@ -960,7 +1019,7 @@ steps:
|
|||||||
|
|
||||||
- label: Language Models Test (PPL)
|
- label: Language Models Test (PPL)
|
||||||
timeout_in_minutes: 110
|
timeout_in_minutes: 110
|
||||||
mirror_hardwares: [amdexperimental]
|
mirror_hardwares: [amdexperimental, amdproduction]
|
||||||
agent_pool: mi325_1
|
agent_pool: mi325_1
|
||||||
# grade: Blocking
|
# grade: Blocking
|
||||||
optional: true
|
optional: true
|
||||||
@@ -972,7 +1031,7 @@ steps:
|
|||||||
|
|
||||||
- label: Language Models Test (Extended Pooling) # 36min
|
- label: Language Models Test (Extended Pooling) # 36min
|
||||||
timeout_in_minutes: 50
|
timeout_in_minutes: 50
|
||||||
mirror_hardwares: [amdexperimental]
|
mirror_hardwares: [amdexperimental, amdproduction]
|
||||||
agent_pool: mi325_1
|
agent_pool: mi325_1
|
||||||
# grade: Blocking
|
# grade: Blocking
|
||||||
optional: true
|
optional: true
|
||||||
@@ -984,7 +1043,7 @@ steps:
|
|||||||
|
|
||||||
- label: Language Models Test (MTEB)
|
- label: Language Models Test (MTEB)
|
||||||
timeout_in_minutes: 110
|
timeout_in_minutes: 110
|
||||||
mirror_hardwares: [amdexperimental]
|
mirror_hardwares: [amdexperimental, amdproduction]
|
||||||
agent_pool: mi325_1
|
agent_pool: mi325_1
|
||||||
# grade: Blocking
|
# grade: Blocking
|
||||||
optional: true
|
optional: true
|
||||||
@@ -996,11 +1055,12 @@ steps:
|
|||||||
|
|
||||||
- label: Multi-Modal Processor Test (CPU)
|
- label: Multi-Modal Processor Test (CPU)
|
||||||
timeout_in_minutes: 60
|
timeout_in_minutes: 60
|
||||||
mirror_hardwares: [amdexperimental]
|
mirror_hardwares: [amdexperimental, amdproduction]
|
||||||
agent_pool: mi325_1
|
agent_pool: mi325_1
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/
|
- vllm/
|
||||||
- tests/models/multimodal
|
- tests/models/multimodal
|
||||||
|
- tests/models/registry.py
|
||||||
no_gpu: true
|
no_gpu: true
|
||||||
commands:
|
commands:
|
||||||
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
|
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
|
||||||
@@ -1008,19 +1068,20 @@ steps:
|
|||||||
|
|
||||||
- label: Multi-Modal Processor Test # 44min
|
- label: Multi-Modal Processor Test # 44min
|
||||||
timeout_in_minutes: 60
|
timeout_in_minutes: 60
|
||||||
mirror_hardwares: [amdexperimental]
|
mirror_hardwares: [amdexperimental, amdproduction]
|
||||||
agent_pool: mi325_1
|
agent_pool: mi325_1
|
||||||
# grade: Blocking
|
# grade: Blocking
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/
|
- vllm/
|
||||||
- tests/models/multimodal
|
- tests/models/multimodal
|
||||||
|
- tests/models/registry.py
|
||||||
commands:
|
commands:
|
||||||
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
|
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
|
||||||
- pytest -v -s models/multimodal/processing
|
- pytest -v -s models/multimodal/processing
|
||||||
|
|
||||||
- label: Multi-Modal Models Test (Standard) # 60min
|
- label: Multi-Modal Models Test (Standard) # 60min
|
||||||
timeout_in_minutes: 100
|
timeout_in_minutes: 100
|
||||||
mirror_hardwares: [amdexperimental]
|
mirror_hardwares: [amdexperimental, amdproduction]
|
||||||
agent_pool: mi325_1
|
agent_pool: mi325_1
|
||||||
# grade: Blocking
|
# grade: Blocking
|
||||||
torch_nightly: true
|
torch_nightly: true
|
||||||
@@ -1053,7 +1114,7 @@ steps:
|
|||||||
|
|
||||||
- label: Multi-Modal Models Test (Extended) 1 # 60min
|
- label: Multi-Modal Models Test (Extended) 1 # 60min
|
||||||
timeout_in_minutes: 120
|
timeout_in_minutes: 120
|
||||||
mirror_hardwares: [amdexperimental]
|
mirror_hardwares: [amdexperimental, amdproduction]
|
||||||
agent_pool: mi325_1
|
agent_pool: mi325_1
|
||||||
# grade: Blocking
|
# grade: Blocking
|
||||||
optional: true
|
optional: true
|
||||||
@@ -1068,7 +1129,7 @@ steps:
|
|||||||
|
|
||||||
- label: Multi-Modal Models Test (Extended) 2 #60min
|
- label: Multi-Modal Models Test (Extended) 2 #60min
|
||||||
timeout_in_minutes: 120
|
timeout_in_minutes: 120
|
||||||
mirror_hardwares: [amdexperimental]
|
mirror_hardwares: [amdexperimental, amdproduction]
|
||||||
agent_pool: mi325_1
|
agent_pool: mi325_1
|
||||||
# grade: Blocking
|
# grade: Blocking
|
||||||
optional: true
|
optional: true
|
||||||
@@ -1083,7 +1144,7 @@ steps:
|
|||||||
|
|
||||||
- label: Multi-Modal Models Test (Extended) 3 # 75min
|
- label: Multi-Modal Models Test (Extended) 3 # 75min
|
||||||
timeout_in_minutes: 150
|
timeout_in_minutes: 150
|
||||||
mirror_hardwares: [amdexperimental]
|
mirror_hardwares: [amdexperimental, amdproduction]
|
||||||
agent_pool: mi325_1
|
agent_pool: mi325_1
|
||||||
# grade: Blocking
|
# grade: Blocking
|
||||||
optional: true
|
optional: true
|
||||||
@@ -1108,7 +1169,7 @@ steps:
|
|||||||
- pytest -v -s models/quantization
|
- pytest -v -s models/quantization
|
||||||
|
|
||||||
- label: Transformers Nightly Models Test
|
- label: Transformers Nightly Models Test
|
||||||
mirror_hardwares: [amdexperimental]
|
mirror_hardwares: [amdexperimental, amdproduction]
|
||||||
agent_pool: mi325_1
|
agent_pool: mi325_1
|
||||||
# grade: Blocking
|
# grade: Blocking
|
||||||
working_dir: "/vllm-workspace/"
|
working_dir: "/vllm-workspace/"
|
||||||
@@ -1166,41 +1227,6 @@ steps:
|
|||||||
- pytest -v -s tests/kernels/moe/test_flashinfer.py
|
- pytest -v -s tests/kernels/moe/test_flashinfer.py
|
||||||
- pytest -v -s tests/kernels/moe/test_cutedsl_moe.py
|
- pytest -v -s tests/kernels/moe/test_cutedsl_moe.py
|
||||||
|
|
||||||
- label: Blackwell Fusion and Compile Tests # 30 min
|
|
||||||
timeout_in_minutes: 40
|
|
||||||
working_dir: "/vllm-workspace/"
|
|
||||||
gpu: b200
|
|
||||||
source_file_dependencies:
|
|
||||||
- csrc/quantization/fp4/
|
|
||||||
- vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
|
|
||||||
- vllm/v1/attention/backends/flashinfer.py
|
|
||||||
- vllm/v1/worker/
|
|
||||||
- vllm/v1/cudagraph_dispatcher.py
|
|
||||||
- vllm/compilation/
|
|
||||||
# can affect pattern matching
|
|
||||||
- vllm/model_executor/layers/layernorm.py
|
|
||||||
- vllm/model_executor/layers/activation.py
|
|
||||||
- vllm/model_executor/layers/quantization/input_quant_fp8.py
|
|
||||||
- tests/compile/passes/test_fusion_attn.py
|
|
||||||
- tests/compile/passes/test_silu_mul_quant_fusion.py
|
|
||||||
- tests/compile/passes/distributed/test_fusion_all_reduce.py
|
|
||||||
- tests/compile/fullgraph/test_full_graph.py
|
|
||||||
commands:
|
|
||||||
- nvidia-smi
|
|
||||||
- pytest -v -s tests/compile/passes/test_fusion_attn.py
|
|
||||||
- pytest -v -s tests/compile/passes/test_silu_mul_quant_fusion.py
|
|
||||||
# this runner has 2 GPUs available even though num_gpus=2 is not set
|
|
||||||
- pytest -v -s tests/compile/passes/distributed/test_fusion_all_reduce.py
|
|
||||||
|
|
||||||
# # Limit to Inductor partition, no custom ops, and allreduce & attn fusion to reduce running time
|
|
||||||
# # Wrap with quotes to escape yaml
|
|
||||||
# - "pytest -v -s tests/compile/distributed/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm -k 'True and not +quant_fp8 and not +rms_norm'"
|
|
||||||
# Old E2E tests were removed in https://github.com/vllm-project/vllm/pull/33293
|
|
||||||
# in favor of new tests in fusions_e2e. We avoid replicating the new jobs in this file as it's deprecated.
|
|
||||||
|
|
||||||
# test_fp8_kv_scale_compile requires FlashAttention (not supported on default L4/L40)
|
|
||||||
- pytest -v -s tests/compile/fullgraph/test_full_graph.py::test_fp8_kv_scale_compile
|
|
||||||
|
|
||||||
- label: Blackwell GPT-OSS Eval
|
- label: Blackwell GPT-OSS Eval
|
||||||
timeout_in_minutes: 60
|
timeout_in_minutes: 60
|
||||||
working_dir: "/vllm-workspace/"
|
working_dir: "/vllm-workspace/"
|
||||||
@@ -1263,8 +1289,9 @@ steps:
|
|||||||
|
|
||||||
- label: 2 Node Tests (4 GPUs in total) # 16min
|
- label: 2 Node Tests (4 GPUs in total) # 16min
|
||||||
timeout_in_minutes: 30
|
timeout_in_minutes: 30
|
||||||
mirror_hardwares: [amdexperimental, amdmultinode]
|
mirror_hardwares: [amdexperimental, amdproduction, amdmultinode]
|
||||||
agent_pool: mi325_4
|
agent_pool: mi325_4
|
||||||
|
optional: true
|
||||||
# grade: Blocking
|
# grade: Blocking
|
||||||
working_dir: "/vllm-workspace/tests"
|
working_dir: "/vllm-workspace/tests"
|
||||||
num_gpus: 2
|
num_gpus: 2
|
||||||
@@ -1290,8 +1317,9 @@ steps:
|
|||||||
|
|
||||||
- label: Distributed Tests (2 GPUs) # 68min
|
- label: Distributed Tests (2 GPUs) # 68min
|
||||||
timeout_in_minutes: 90
|
timeout_in_minutes: 90
|
||||||
mirror_hardwares: [amdexperimental]
|
mirror_hardwares: [amdexperimental, amdproduction]
|
||||||
agent_pool: mi325_2
|
agent_pool: mi325_2
|
||||||
|
optional: true
|
||||||
# grade: Blocking
|
# grade: Blocking
|
||||||
working_dir: "/vllm-workspace/tests"
|
working_dir: "/vllm-workspace/tests"
|
||||||
num_gpus: 2
|
num_gpus: 2
|
||||||
@@ -1311,6 +1339,7 @@ steps:
|
|||||||
- tests/v1/entrypoints/openai/test_multi_api_servers.py
|
- tests/v1/entrypoints/openai/test_multi_api_servers.py
|
||||||
- tests/v1/shutdown
|
- tests/v1/shutdown
|
||||||
- tests/v1/worker/test_worker_memory_snapshot.py
|
- tests/v1/worker/test_worker_memory_snapshot.py
|
||||||
|
- examples/offline_inference/new_weight_syncing/
|
||||||
commands:
|
commands:
|
||||||
# Work around HIP bug tracked here: https://github.com/ROCm/hip/issues/3876
|
# Work around HIP bug tracked here: https://github.com/ROCm/hip/issues/3876
|
||||||
# TODO: Remove when the bug is fixed in a future ROCm release
|
# TODO: Remove when the bug is fixed in a future ROCm release
|
||||||
@@ -1324,14 +1353,14 @@ steps:
|
|||||||
- pytest -v -s ./compile/test_wrapper.py
|
- pytest -v -s ./compile/test_wrapper.py
|
||||||
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
|
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
|
||||||
- VLLM_TEST_SAME_HOST=1 VLLM_TEST_WITH_DEFAULT_DEVICE_SET=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
|
- VLLM_TEST_SAME_HOST=1 VLLM_TEST_WITH_DEFAULT_DEVICE_SET=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
|
||||||
- pytest -v -s compile/correctness_e2e/test_sequence_parallel.py
|
|
||||||
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown
|
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown
|
||||||
- pytest -v -s v1/worker/test_worker_memory_snapshot.py
|
- pytest -v -s v1/worker/test_worker_memory_snapshot.py
|
||||||
|
|
||||||
- label: Distributed Model Tests (2 GPUs) # 37min
|
- label: Distributed Model Tests (2 GPUs) # 37min
|
||||||
timeout_in_minutes: 50
|
timeout_in_minutes: 50
|
||||||
mirror_hardwares: [amdexperimental]
|
mirror_hardwares: [amdexperimental, amdproduction]
|
||||||
agent_pool: mi325_2
|
agent_pool: mi325_2
|
||||||
|
optional: true
|
||||||
# grade: Blocking
|
# grade: Blocking
|
||||||
working_dir: "/vllm-workspace/tests"
|
working_dir: "/vllm-workspace/tests"
|
||||||
num_gpus: 2
|
num_gpus: 2
|
||||||
@@ -1370,6 +1399,10 @@ steps:
|
|||||||
- pip install -e ./plugins/prithvi_io_processor_plugin
|
- pip install -e ./plugins/prithvi_io_processor_plugin
|
||||||
- pytest -v -s plugins_tests/test_io_processor_plugins.py
|
- pytest -v -s plugins_tests/test_io_processor_plugins.py
|
||||||
- pip uninstall prithvi_io_processor_plugin -y
|
- pip uninstall prithvi_io_processor_plugin -y
|
||||||
|
# test bge_m3_sparse io_processor plugin
|
||||||
|
- pip install -e ./plugins/bge_m3_sparse_plugin
|
||||||
|
- pytest -v -s plugins_tests/test_bge_m3_sparse_io_processor_plugins.py
|
||||||
|
- pip uninstall bge_m3_sparse_plugin -y
|
||||||
# end io_processor plugins test
|
# end io_processor plugins test
|
||||||
# begin stat_logger plugins test
|
# begin stat_logger plugins test
|
||||||
- pip install -e ./plugins/vllm_add_dummy_stat_logger
|
- pip install -e ./plugins/vllm_add_dummy_stat_logger
|
||||||
@@ -1441,7 +1474,7 @@ steps:
|
|||||||
- bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models-amd.txt
|
- bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models-amd.txt
|
||||||
|
|
||||||
- label: Weight Loading Multiple GPU Test - Large Models # optional
|
- label: Weight Loading Multiple GPU Test - Large Models # optional
|
||||||
mirror_hardwares: [amdexperimental]
|
mirror_hardwares: [amdexperimental, amdproduction]
|
||||||
agent_pool: mi325_2
|
agent_pool: mi325_2
|
||||||
# grade: Blocking
|
# grade: Blocking
|
||||||
working_dir: "/vllm-workspace/tests"
|
working_dir: "/vllm-workspace/tests"
|
||||||
@@ -1485,7 +1518,7 @@ steps:
|
|||||||
##### A100 test #####
|
##### A100 test #####
|
||||||
|
|
||||||
- label: Distributed Tests (A100) # optional
|
- label: Distributed Tests (A100) # optional
|
||||||
mirror_hardwares: [amdexperimental]
|
mirror_hardwares: [amdexperimental, amdproduction]
|
||||||
agent_pool: mi325_4
|
agent_pool: mi325_4
|
||||||
# grade: Blocking
|
# grade: Blocking
|
||||||
gpu: a100
|
gpu: a100
|
||||||
@@ -1508,7 +1541,7 @@ steps:
|
|||||||
- label: LM Eval Large Models # optional
|
- label: LM Eval Large Models # optional
|
||||||
gpu: a100
|
gpu: a100
|
||||||
optional: true
|
optional: true
|
||||||
mirror_hardwares: [amdexperimental]
|
mirror_hardwares: [amdexperimental, amdproduction]
|
||||||
agent_pool: mi325_4
|
agent_pool: mi325_4
|
||||||
# grade: Blocking
|
# grade: Blocking
|
||||||
num_gpus: 4
|
num_gpus: 4
|
||||||
@@ -1520,11 +1553,11 @@ steps:
|
|||||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||||
- pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4
|
- pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4
|
||||||
|
|
||||||
##### H100 test #####
|
##### FP8 test #####
|
||||||
- label: LM Eval Large Models (H100) # optional
|
- label: LM Eval Large Models (H100) # optional, still use H100 for consistency
|
||||||
gpu: h100
|
gpu: h100
|
||||||
optional: true
|
optional: true
|
||||||
mirror_hardwares: [amdexperimental]
|
mirror_hardwares: [amdexperimental, amdproduction]
|
||||||
agent_pool: mi325_4
|
agent_pool: mi325_4
|
||||||
# grade: Blocking
|
# grade: Blocking
|
||||||
num_gpus: 4
|
num_gpus: 4
|
||||||
@@ -1533,13 +1566,13 @@ steps:
|
|||||||
- csrc/
|
- csrc/
|
||||||
- vllm/model_executor/layers/quantization
|
- vllm/model_executor/layers/quantization
|
||||||
commands:
|
commands:
|
||||||
- export VLLM_USE_DEEP_GEMM=0 # We found Triton is faster than DeepGEMM for H100
|
- export VLLM_USE_DEEP_GEMM=0
|
||||||
- pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large-hopper.txt --tp-size=4
|
- pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large-rocm.txt --tp-size=4
|
||||||
|
|
||||||
|
|
||||||
##### H200 test #####
|
##### H200 test #####
|
||||||
- label: Distributed Tests (H200) # optional
|
- label: Distributed Tests (H200) # optional
|
||||||
mirror_hardwares: [amdexperimental]
|
mirror_hardwares: [amdexperimental, amdproduction]
|
||||||
agent_pool: mi325_2
|
agent_pool: mi325_2
|
||||||
# grade: Blocking
|
# grade: Blocking
|
||||||
gpu: h200
|
gpu: h200
|
||||||
@@ -1549,16 +1582,16 @@ steps:
|
|||||||
commands:
|
commands:
|
||||||
- VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/passes/distributed/test_async_tp.py
|
- VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/passes/distributed/test_async_tp.py
|
||||||
- pytest -v -s tests/compile/passes/distributed/test_sequence_parallelism.py
|
- pytest -v -s tests/compile/passes/distributed/test_sequence_parallelism.py
|
||||||
- pytest -v -s tests/compile/passes/distributed/test_fusion_all_reduce.py
|
# TODO: this test is not supported on ROCm, there are aiter kernels for this.
|
||||||
|
# - pytest -v -s tests/compile/passes/distributed/test_fusion_all_reduce.py
|
||||||
#- pytest -v -s tests/compile/distributed/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm
|
#- pytest -v -s tests/compile/distributed/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm
|
||||||
# - "VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/distributed/test_fusions_e2e.py -k 'not Llama-4'"
|
# - "VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/distributed/test_fusions_e2e.py -k 'not Llama-4'"
|
||||||
# Old E2E tests were removed in https://github.com/vllm-project/vllm/pull/33293
|
# Old E2E tests were removed in https://github.com/vllm-project/vllm/pull/33293
|
||||||
# in favor of new tests in fusions_e2e. We avoid replicating the new jobs in this file as it's deprecated.
|
# in favor of new tests in fusions_e2e. We avoid replicating the new jobs in this file as it's deprecated.
|
||||||
|
|
||||||
- VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/correctness_e2e/test_sequence_parallel.py
|
|
||||||
- pytest -v -s tests/distributed/test_context_parallel.py
|
- pytest -v -s tests/distributed/test_context_parallel.py
|
||||||
- HIP_VISIBLE_DEVICES=0,1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model=Qwen/Qwen1.5-MoE-A2.7B -tp=1 -dp=2 --max-model-len=2048 --all2all-backend=allgather_reducescatter --disable-nccl-for-dp-synchronization
|
- HIP_VISIBLE_DEVICES=0,1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model=Qwen/Qwen1.5-MoE-A2.7B -tp=1 -dp=2 --max-model-len=2048 --all2all-backend=allgather_reducescatter --disable-nccl-for-dp-synchronization
|
||||||
- pytest -v -s tests/v1/distributed/test_dbo.py
|
# this test is not supported on ROCm
|
||||||
|
# - pytest -v -s tests/v1/distributed/test_dbo.py
|
||||||
|
|
||||||
##### B200 test #####
|
##### B200 test #####
|
||||||
- label: Distributed Tests (B200) # optional
|
- label: Distributed Tests (B200) # optional
|
||||||
@@ -1599,8 +1632,9 @@ steps:
|
|||||||
- pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4
|
- pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4
|
||||||
|
|
||||||
- label: ROCm LM Eval Large Models (8 Card)
|
- label: ROCm LM Eval Large Models (8 Card)
|
||||||
mirror_hardwares: [amdproduction]
|
mirror_hardwares: [amdexperimental, amdproduction]
|
||||||
agent_pool: mi325_8
|
agent_pool: mi325_8
|
||||||
|
optional: true
|
||||||
num_gpus: 8
|
num_gpus: 8
|
||||||
working_dir: "/vllm-workspace/.buildkite/lm-eval-harness"
|
working_dir: "/vllm-workspace/.buildkite/lm-eval-harness"
|
||||||
commands:
|
commands:
|
||||||
@@ -1659,7 +1693,7 @@ steps:
|
|||||||
|
|
||||||
- label: Qwen3-Next-80B-A3B-Instruct MTP Async EPLB Accuracy
|
- label: Qwen3-Next-80B-A3B-Instruct MTP Async EPLB Accuracy
|
||||||
timeout_in_minutes: 60
|
timeout_in_minutes: 60
|
||||||
mirror_hardwares: [amdexperimental]
|
mirror_hardwares: [amdexperimental, amdproduction]
|
||||||
agent_pool: mi325_4
|
agent_pool: mi325_4
|
||||||
# grade: Blocking
|
# grade: Blocking
|
||||||
optional: true
|
optional: true
|
||||||
@@ -1668,6 +1702,93 @@ steps:
|
|||||||
commands:
|
commands:
|
||||||
- bash .buildkite/scripts/scheduled_integration_test/qwen3_next_mtp_async_eplb.sh 0.8 1319 8040
|
- bash .buildkite/scripts/scheduled_integration_test/qwen3_next_mtp_async_eplb.sh 0.8 1319 8040
|
||||||
|
|
||||||
|
##### .buildkite/test_areas/compile.yaml #####
|
||||||
|
# Slowly setting up the tests so that it is also easier for the
|
||||||
|
# CI team to review and upstream to the pipelinev2.
|
||||||
|
# The following tests are important for vLLM IR Ops refactoring,
|
||||||
|
# which affects fusion passes on ROCm. So we have to
|
||||||
|
# enable them as as soon as possible.
|
||||||
|
|
||||||
|
## TODO: Enable the test in this group
|
||||||
|
# # corresponds to .buildkite/test_areas/compile.yaml
|
||||||
|
# - label: Fusion and Compile Unit Tests (2xMI325 GPUs)
|
||||||
|
# timeout_in_minutes: 20
|
||||||
|
# working_dir: "/vllm-workspace/"
|
||||||
|
# mirror_hardwares: [amdexperimental, amdproduction, tj]
|
||||||
|
# agent_pool: mi325_1 # changed to 1 GPU until the fusion all reduce is enabled then only revert back to 2 GPUs
|
||||||
|
# source_file_dependencies:
|
||||||
|
# - csrc/quantization/fp4/
|
||||||
|
# - vllm/model_executor/layers/quantization/
|
||||||
|
# - vllm/model_executor/layers/layernorm.py
|
||||||
|
# - vllm/model_executor/layers/activation.py
|
||||||
|
# - vllm/model_executor/layers/attention/attention.py
|
||||||
|
# - vllm/v1/attention/backends/flashinfer.py
|
||||||
|
# - vllm/compilation/ # TODO(luka) limit to vllm/compilation/passes
|
||||||
|
# - tests/compile/test_fusion_attn.py
|
||||||
|
# - tests/compile/test_silu_mul_quant_fusion.py
|
||||||
|
# - tests/compile/distributed/test_fusion_all_reduce.py
|
||||||
|
# - tests/compile/fullgraph/test_full_graph.py
|
||||||
|
# commands:
|
||||||
|
# - rocm-smi
|
||||||
|
# # we run all backend tests on ROCm
|
||||||
|
# # These two tests are covered in "PyTorch Compilation Passes Unit Tests"
|
||||||
|
# # - "pytest -v -s tests/compile/passes/test_fusion_attn.py"
|
||||||
|
# # - "pytest -v -s tests/compile/passes/test_silu_mul_quant_fusion.py"
|
||||||
|
# # TODO: this test is not supported on ROCm, there are aiter kernels for this.
|
||||||
|
# # - pytest -v -s tests/compile/passes/distributed/test_fusion_all_reduce.py
|
||||||
|
# # TODO: find out more details
|
||||||
|
# # - pytest -v -s tests/compile/fullgraph/test_full_graph.py::test_fp8_kv_scale_compile
|
||||||
|
|
||||||
|
# corresponds to .buildkite/test_areas/compile.yaml
|
||||||
|
- label: Fusion E2E Quick (MI325)
|
||||||
|
timeout_in_minutes: 15
|
||||||
|
working_dir: "/vllm-workspace/"
|
||||||
|
mirror_hardwares: [amdexperimental, amdproduction]
|
||||||
|
agent_pool: mi325_1
|
||||||
|
num_devices: 1
|
||||||
|
source_file_dependencies:
|
||||||
|
- csrc/quantization/
|
||||||
|
- vllm/model_executor/
|
||||||
|
- vllm/v1/attention/
|
||||||
|
- vllm/compilation/
|
||||||
|
- tests/compile/fusions_e2e/
|
||||||
|
commands:
|
||||||
|
- rocm-smi
|
||||||
|
# Run all models and attn backends but only Inductor partition and native custom ops
|
||||||
|
- "pytest -v -s tests/compile/fusions_e2e/test_tp1_quant.py -k 'inductor_partition and not +rms_norm and not +quant_fp8'"
|
||||||
|
# Different from CUDA, Qwen requires +rms_norm and +quant_fp8 as rms+quant fusion is only supported on AITER
|
||||||
|
- "pytest -v -s tests/compile/fusions_e2e/test_tp1_quant.py -k 'inductor_partition and +rms_norm and +quant_fp8 and qwen3'"
|
||||||
|
|
||||||
|
# corresponds to .buildkite/test_areas/compile.yaml
|
||||||
|
- label: Fusion E2E Config Sweep (MI325)
|
||||||
|
timeout_in_minutes: 30
|
||||||
|
working_dir: "/vllm-workspace/"
|
||||||
|
mirror_hardwares: [amdexperimental, amdproduction]
|
||||||
|
agent_pool: mi325_1
|
||||||
|
num_devices: 1
|
||||||
|
source_file_dependencies:
|
||||||
|
- csrc/quantization/
|
||||||
|
- vllm/compilation/
|
||||||
|
# can affect pattern matching
|
||||||
|
- vllm/model_executor/layers/layernorm.py
|
||||||
|
- vllm/model_executor/layers/activation.py
|
||||||
|
- vllm/model_executor/layers/attention/attention.py
|
||||||
|
- vllm/model_executor/layers/quantization/input_quant_fp8.py
|
||||||
|
- tests/compile/fusions_e2e/
|
||||||
|
commands:
|
||||||
|
- rocm-smi
|
||||||
|
# Run just llama3 (fp8) for all config combinations
|
||||||
|
- pytest -v -s tests/compile/fusions_e2e/test_tp1_quant.py -k "llama-3"
|
||||||
|
|
||||||
|
## There are no ops on ROCm for these tests.
|
||||||
|
## The test still passes but the logs are not useful.
|
||||||
|
## fused ops just call torch.ops.symm_mem which
|
||||||
|
## exists in ROCm even though they don't work
|
||||||
|
# - label: AsyncTP Correctness Tests (2xMI325 GPUs)
|
||||||
|
# - label: Fusion E2E TP2 Quick (MI325)
|
||||||
|
# - label: Fusion E2E TP2 AsyncTP Config Sweep (MI325)
|
||||||
|
# - label: Fusion E2E TP2 (MI325)
|
||||||
|
# - label: Sequence Parallel Correctness Tests (2xMI325 GPUs)
|
||||||
|
|
||||||
|
|
||||||
#####################################################################################################################################
|
#####################################################################################################################################
|
||||||
@@ -1850,8 +1971,10 @@ steps:
|
|||||||
|
|
||||||
- label: Distributed Tests (4 GPUs) # 35min
|
- label: Distributed Tests (4 GPUs) # 35min
|
||||||
timeout_in_minutes: 50
|
timeout_in_minutes: 50
|
||||||
mirror_hardwares: [amdexperimental]
|
mirror_hardwares: [amdexperimental, amdproduction]
|
||||||
agent_pool: mi355_4
|
agent_pool: mi355_4
|
||||||
|
optional: true
|
||||||
|
# grade: Blocking
|
||||||
working_dir: "/vllm-workspace/tests"
|
working_dir: "/vllm-workspace/tests"
|
||||||
num_gpus: 4
|
num_gpus: 4
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
@@ -1905,7 +2028,8 @@ steps:
|
|||||||
- popd
|
- popd
|
||||||
# NEW rlhf examples
|
# NEW rlhf examples
|
||||||
- pushd ../examples/offline_inference/new_weight_syncing
|
- pushd ../examples/offline_inference/new_weight_syncing
|
||||||
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py
|
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_nccl.py
|
||||||
|
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_ipc.py
|
||||||
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_async_new_apis.py
|
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_async_new_apis.py
|
||||||
- popd
|
- popd
|
||||||
|
|
||||||
@@ -2869,8 +2993,10 @@ steps:
|
|||||||
|
|
||||||
- label: Distributed Tests (2 GPUs) # 68min
|
- label: Distributed Tests (2 GPUs) # 68min
|
||||||
timeout_in_minutes: 90
|
timeout_in_minutes: 90
|
||||||
mirror_hardwares: [amdexperimental]
|
mirror_hardwares: [amdexperimental, amdproduction]
|
||||||
agent_pool: mi355_2
|
agent_pool: mi355_2
|
||||||
|
optional: true
|
||||||
|
# grade: Blocking
|
||||||
working_dir: "/vllm-workspace/tests"
|
working_dir: "/vllm-workspace/tests"
|
||||||
num_gpus: 2
|
num_gpus: 2
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
@@ -2946,6 +3072,10 @@ steps:
|
|||||||
- pip install -e ./plugins/prithvi_io_processor_plugin
|
- pip install -e ./plugins/prithvi_io_processor_plugin
|
||||||
- pytest -v -s plugins_tests/test_io_processor_plugins.py
|
- pytest -v -s plugins_tests/test_io_processor_plugins.py
|
||||||
- pip uninstall prithvi_io_processor_plugin -y
|
- pip uninstall prithvi_io_processor_plugin -y
|
||||||
|
# test bge_m3_sparse io_processor plugin
|
||||||
|
- pip install -e ./plugins/bge_m3_sparse_plugin
|
||||||
|
- pytest -v -s plugins_tests/test_bge_m3_sparse_io_processor_plugins.py
|
||||||
|
- pip uninstall bge_m3_sparse_plugin -y
|
||||||
# end io_processor plugins test
|
# end io_processor plugins test
|
||||||
# begin stat_logger plugins test
|
# begin stat_logger plugins test
|
||||||
- pip install -e ./plugins/vllm_add_dummy_stat_logger
|
- pip install -e ./plugins/vllm_add_dummy_stat_logger
|
||||||
|
|||||||
@@ -103,7 +103,8 @@ steps:
|
|||||||
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py
|
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py
|
||||||
# NEW rlhf examples
|
# NEW rlhf examples
|
||||||
- cd new_weight_syncing
|
- cd new_weight_syncing
|
||||||
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py
|
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_nccl.py
|
||||||
|
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_ipc.py
|
||||||
|
|
||||||
- label: Distributed Tests (8 GPUs)(H100)
|
- label: Distributed Tests (8 GPUs)(H100)
|
||||||
timeout_in_minutes: 10
|
timeout_in_minutes: 10
|
||||||
@@ -145,7 +146,7 @@ steps:
|
|||||||
num_devices: 2
|
num_devices: 2
|
||||||
commands:
|
commands:
|
||||||
- pytest -v -s tests/distributed/test_context_parallel.py
|
- pytest -v -s tests/distributed/test_context_parallel.py
|
||||||
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py
|
# - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py --- failing, need to re-enable
|
||||||
- VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model=Qwen/Qwen1.5-MoE-A2.7B -tp=1 -dp=2 --max-model-len=2048 --all2all-backend=deepep_high_throughput
|
- VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model=Qwen/Qwen1.5-MoE-A2.7B -tp=1 -dp=2 --max-model-len=2048 --all2all-backend=deepep_high_throughput
|
||||||
- pytest -v -s tests/v1/distributed/test_dbo.py
|
- pytest -v -s tests/v1/distributed/test_dbo.py
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ steps:
|
|||||||
commands:
|
commands:
|
||||||
- pytest -v -s engine test_sequence.py test_config.py test_logger.py test_vllm_port.py
|
- pytest -v -s engine test_sequence.py test_config.py test_logger.py test_vllm_port.py
|
||||||
|
|
||||||
- label: V1 e2e + engine
|
- label: V1 e2e + engine (1 GPU)
|
||||||
timeout_in_minutes: 45
|
timeout_in_minutes: 45
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/
|
- vllm/
|
||||||
@@ -36,3 +36,35 @@ steps:
|
|||||||
commands:
|
commands:
|
||||||
- pytest -v -s v1/e2e
|
- pytest -v -s v1/e2e
|
||||||
- pytest -v -s v1/engine
|
- pytest -v -s v1/engine
|
||||||
|
|
||||||
|
- label: V1 e2e (2 GPUs)
|
||||||
|
timeout_in_minutes: 60 # TODO: Fix timeout after we have more confidence in the test stability
|
||||||
|
optional: true
|
||||||
|
num_devices: 2
|
||||||
|
source_file_dependencies:
|
||||||
|
- vllm/
|
||||||
|
- tests/v1/e2e
|
||||||
|
commands:
|
||||||
|
# Only run tests that need exactly 2 GPUs
|
||||||
|
- pytest -v -s v1/e2e/test_spec_decode.py -k "tensor_parallelism"
|
||||||
|
mirror:
|
||||||
|
amd:
|
||||||
|
device: mi325_2
|
||||||
|
depends_on:
|
||||||
|
- image-build-amd
|
||||||
|
|
||||||
|
- label: V1 e2e (4 GPUs)
|
||||||
|
timeout_in_minutes: 60 # TODO: Fix timeout after we have more confidence in the test stability
|
||||||
|
optional: true
|
||||||
|
num_devices: 4
|
||||||
|
source_file_dependencies:
|
||||||
|
- vllm/
|
||||||
|
- tests/v1/e2e
|
||||||
|
commands:
|
||||||
|
# Only run tests that need 4 GPUs
|
||||||
|
- pytest -v -s v1/e2e/test_spec_decode.py -k "eagle_correctness_heavy"
|
||||||
|
mirror:
|
||||||
|
amd:
|
||||||
|
device: mi325_4
|
||||||
|
depends_on:
|
||||||
|
- image-build-amd
|
||||||
|
|||||||
@@ -21,3 +21,18 @@ steps:
|
|||||||
commands:
|
commands:
|
||||||
- pytest -v -s distributed/test_eplb_execute.py
|
- pytest -v -s distributed/test_eplb_execute.py
|
||||||
- pytest -v -s distributed/test_eplb_spec_decode.py
|
- pytest -v -s distributed/test_eplb_spec_decode.py
|
||||||
|
|
||||||
|
- label: Elastic EP Scaling Test
|
||||||
|
timeout_in_minutes: 20
|
||||||
|
device: b200
|
||||||
|
optional: true
|
||||||
|
working_dir: "/vllm-workspace/tests"
|
||||||
|
num_devices: 4
|
||||||
|
source_file_dependencies:
|
||||||
|
- vllm/distributed/
|
||||||
|
- vllm/engine/
|
||||||
|
- vllm/executor/
|
||||||
|
- vllm/compilation/
|
||||||
|
- tests/distributed/
|
||||||
|
commands:
|
||||||
|
- pytest -v -s distributed/test_elastic_ep.py
|
||||||
|
|||||||
@@ -44,7 +44,8 @@ steps:
|
|||||||
- vllm/envs.py
|
- vllm/envs.py
|
||||||
- vllm/config
|
- vllm/config
|
||||||
commands:
|
commands:
|
||||||
- pytest -v -s kernels/moe --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
|
- pytest -v -s kernels/moe --ignore=kernels/moe/test_modular_oai_triton_moe.py --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
|
||||||
|
- pytest -v -s kernels/moe/test_modular_oai_triton_moe.py --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
|
||||||
parallelism: 2
|
parallelism: 2
|
||||||
|
|
||||||
- label: Kernels Mamba Test
|
- label: Kernels Mamba Test
|
||||||
@@ -70,7 +71,7 @@ steps:
|
|||||||
- tests/kernels/moe/test_batched_deepgemm.py
|
- tests/kernels/moe/test_batched_deepgemm.py
|
||||||
- tests/kernels/attention/test_deepgemm_attention.py
|
- tests/kernels/attention/test_deepgemm_attention.py
|
||||||
commands:
|
commands:
|
||||||
- pytest -v -s kernels/quantization/test_block_fp8.py -k deep_gemm
|
- pytest -v -s kernels/quantization/test_block_fp8.py
|
||||||
- pytest -v -s kernels/moe/test_deepgemm.py
|
- pytest -v -s kernels/moe/test_deepgemm.py
|
||||||
- pytest -v -s kernels/moe/test_batched_deepgemm.py
|
- pytest -v -s kernels/moe/test_batched_deepgemm.py
|
||||||
- pytest -v -s kernels/attention/test_deepgemm_attention.py
|
- pytest -v -s kernels/attention/test_deepgemm_attention.py
|
||||||
@@ -155,5 +156,14 @@ steps:
|
|||||||
commands:
|
commands:
|
||||||
- pytest -v -s kernels/moe/test_deepep_deepgemm_moe.py
|
- pytest -v -s kernels/moe/test_deepep_deepgemm_moe.py
|
||||||
- pytest -v -s kernels/moe/test_deepep_moe.py
|
- pytest -v -s kernels/moe/test_deepep_moe.py
|
||||||
- pytest -v -s kernels/moe/test_pplx_cutlass_moe.py
|
|
||||||
# - pytest -v -s kernels/moe/test_pplx_moe.py - failing on main
|
- label: Kernels Fp4 MoE Test (B200)
|
||||||
|
timeout_in_minutes: 60
|
||||||
|
device: b200
|
||||||
|
num_devices: 1
|
||||||
|
optional: true
|
||||||
|
commands:
|
||||||
|
- pytest -v -s kernels/moe/test_cutedsl_moe.py
|
||||||
|
- pytest -v -s kernels/moe/test_flashinfer_moe.py
|
||||||
|
- pytest -v -s kernels/moe/test_nvfp4_moe.py
|
||||||
|
- pytest -v -s kernels/moe/test_ocp_mx_moe.py
|
||||||
|
|||||||
@@ -11,17 +11,17 @@ steps:
|
|||||||
commands:
|
commands:
|
||||||
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt
|
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt
|
||||||
|
|
||||||
- label: LM Eval Large Models (4 GPUs)(A100)
|
# - label: LM Eval Large Models (4 GPUs)(A100)
|
||||||
device: a100
|
# device: a100
|
||||||
optional: true
|
# optional: true
|
||||||
num_devices: 4
|
# num_devices: 4
|
||||||
working_dir: "/vllm-workspace/.buildkite/lm-eval-harness"
|
# working_dir: "/vllm-workspace/.buildkite/lm-eval-harness"
|
||||||
source_file_dependencies:
|
# source_file_dependencies:
|
||||||
- csrc/
|
# - csrc/
|
||||||
- vllm/model_executor/layers/quantization
|
# - vllm/model_executor/layers/quantization
|
||||||
commands:
|
# commands:
|
||||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
# - export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||||
- pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4
|
# - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4
|
||||||
|
|
||||||
- label: LM Eval Large Models (4 GPUs)(H100)
|
- label: LM Eval Large Models (4 GPUs)(H100)
|
||||||
device: h100
|
device: h100
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ steps:
|
|||||||
- tests/v1
|
- tests/v1
|
||||||
commands:
|
commands:
|
||||||
- uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt
|
- uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt
|
||||||
|
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||||
# split the test to avoid interference
|
# split the test to avoid interference
|
||||||
- pytest -v -s -m 'not cpu_test' v1/core
|
- pytest -v -s -m 'not cpu_test' v1/core
|
||||||
- pytest -v -s v1/executor
|
- pytest -v -s v1/executor
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ steps:
|
|||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/
|
- vllm/
|
||||||
- tests/models/multimodal
|
- tests/models/multimodal
|
||||||
|
- tests/models/registry.py
|
||||||
device: cpu
|
device: cpu
|
||||||
commands:
|
commands:
|
||||||
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
|
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
|
||||||
@@ -30,6 +31,7 @@ steps:
|
|||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/
|
- vllm/
|
||||||
- tests/models/multimodal
|
- tests/models/multimodal
|
||||||
|
- tests/models/registry.py
|
||||||
commands:
|
commands:
|
||||||
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
|
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
|
||||||
- pytest -v -s models/multimodal/processing/test_tensor_schema.py
|
- pytest -v -s models/multimodal/processing/test_tensor_schema.py
|
||||||
@@ -70,12 +72,3 @@ steps:
|
|||||||
commands:
|
commands:
|
||||||
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
|
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
|
||||||
- pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=1) and not core_model'
|
- pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=1) and not core_model'
|
||||||
|
|
||||||
# This test is used only in PR development phase to test individual models and should never run on main
|
|
||||||
- label: Custom Models
|
|
||||||
optional: true
|
|
||||||
commands:
|
|
||||||
- echo 'Testing custom models...'
|
|
||||||
# PR authors can temporarily add commands below to test individual models
|
|
||||||
# e.g. pytest -v -s models/encoder_decoder/vision_language/test_mllama.py
|
|
||||||
# *To avoid merge conflicts, remember to REMOVE (not just comment out) them before merging the PR*
|
|
||||||
|
|||||||
@@ -19,6 +19,10 @@ steps:
|
|||||||
- pip install -e ./plugins/prithvi_io_processor_plugin
|
- pip install -e ./plugins/prithvi_io_processor_plugin
|
||||||
- pytest -v -s plugins_tests/test_io_processor_plugins.py
|
- pytest -v -s plugins_tests/test_io_processor_plugins.py
|
||||||
- pip uninstall prithvi_io_processor_plugin -y
|
- pip uninstall prithvi_io_processor_plugin -y
|
||||||
|
# test bge_m3_sparse io_processor plugin
|
||||||
|
- pip install -e ./plugins/bge_m3_sparse_plugin
|
||||||
|
- pytest -v -s plugins_tests/test_bge_m3_sparse_io_processor_plugins.py
|
||||||
|
- pip uninstall bge_m3_sparse_plugin -y
|
||||||
# end io_processor plugins test
|
# end io_processor plugins test
|
||||||
# begin stat_logger plugins test
|
# begin stat_logger plugins test
|
||||||
- pip install -e ./plugins/vllm_add_dummy_stat_logger
|
- pip install -e ./plugins/vllm_add_dummy_stat_logger
|
||||||
|
|||||||
16
.buildkite/test_areas/ray_compat.yaml
Normal file
16
.buildkite/test_areas/ray_compat.yaml
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
group: Ray Compatibility
|
||||||
|
depends_on:
|
||||||
|
- image-build
|
||||||
|
steps:
|
||||||
|
- label: Ray Dependency Compatibility Check
|
||||||
|
# Informational only — does not block the pipeline.
|
||||||
|
# If this fails, it means the PR introduces a dependency that
|
||||||
|
# conflicts with Ray's dependency constraints.
|
||||||
|
# See https://github.com/vllm-project/vllm/issues/33599
|
||||||
|
soft_fail: true
|
||||||
|
timeout_in_minutes: 10
|
||||||
|
source_file_dependencies:
|
||||||
|
- requirements/
|
||||||
|
- setup.py
|
||||||
|
commands:
|
||||||
|
- bash /vllm-workspace/.buildkite/scripts/check-ray-compatibility.sh
|
||||||
@@ -13,13 +13,13 @@ steps:
|
|||||||
commands:
|
commands:
|
||||||
- bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models.txt
|
- bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models.txt
|
||||||
|
|
||||||
- label: Weight Loading Multiple GPU - Large Models # optional
|
# - label: Weight Loading Multiple GPU - Large Models # optional
|
||||||
working_dir: "/vllm-workspace/tests"
|
# working_dir: "/vllm-workspace/tests"
|
||||||
num_devices: 2
|
# num_devices: 2
|
||||||
device: a100
|
# device: a100
|
||||||
optional: true
|
# optional: true
|
||||||
source_file_dependencies:
|
# source_file_dependencies:
|
||||||
- vllm/
|
# - vllm/
|
||||||
- tests/weight_loading
|
# - tests/weight_loading
|
||||||
commands:
|
# commands:
|
||||||
- bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models-large.txt
|
# - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models-large.txt
|
||||||
|
|||||||
24
.github/.bc-linter.yml
vendored
24
.github/.bc-linter.yml
vendored
@@ -1,24 +0,0 @@
|
|||||||
# doc: https://github.com/pytorch/test-infra/blob/main/tools/stronghold/docs/bc_linter_config.md
|
|
||||||
version: 1
|
|
||||||
paths:
|
|
||||||
# We temporarily disable globally, and will only enable with `annotations.include`
|
|
||||||
# include:
|
|
||||||
# - "vllm/v1/attetion/*.py"
|
|
||||||
# - "vllm/v1/core/*.py"
|
|
||||||
exclude:
|
|
||||||
- "**/*.py"
|
|
||||||
|
|
||||||
scan:
|
|
||||||
functions: true # check free functions and methods
|
|
||||||
classes: true # check classes/dataclasses
|
|
||||||
public_only: true # ignore names starting with "_" at any level
|
|
||||||
|
|
||||||
annotations:
|
|
||||||
include: # decorators that force‑include a symbol
|
|
||||||
- name: "bc_linter_include" # matched by simple name or dotted suffix
|
|
||||||
propagate_to_members: false # for classes, include methods/inner classes
|
|
||||||
exclude: # decorators that force‑exclude a symbol
|
|
||||||
- name: "bc_linter_skip" # matched by simple name or dotted suffix
|
|
||||||
propagate_to_members: true # for classes, exclude methods/inner classes
|
|
||||||
|
|
||||||
excluded_violations: [] # e.g. ["ParameterRenamed", "FieldTypeChanged"]
|
|
||||||
9
.github/CODEOWNERS
vendored
9
.github/CODEOWNERS
vendored
@@ -2,7 +2,7 @@
|
|||||||
# for more info about CODEOWNERS file
|
# for more info about CODEOWNERS file
|
||||||
|
|
||||||
# This lists cover the "core" components of vLLM that require careful review
|
# This lists cover the "core" components of vLLM that require careful review
|
||||||
/vllm/compilation @zou3519 @youkaichao @ProExpertProg
|
/vllm/compilation @zou3519 @youkaichao @ProExpertProg @BoyuanFeng
|
||||||
/vllm/distributed/kv_transfer @NickLucche @ApostaC @orozery
|
/vllm/distributed/kv_transfer @NickLucche @ApostaC @orozery
|
||||||
/vllm/lora @jeejeelee
|
/vllm/lora @jeejeelee
|
||||||
/vllm/model_executor/layers/attention @LucasWilkinson @MatthewBonanni
|
/vllm/model_executor/layers/attention @LucasWilkinson @MatthewBonanni
|
||||||
@@ -54,11 +54,14 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
|
|||||||
/vllm/v1/structured_output @mgoin @russellb @aarnphm @benchislett
|
/vllm/v1/structured_output @mgoin @russellb @aarnphm @benchislett
|
||||||
/vllm/v1/kv_cache_interface.py @heheda12345
|
/vllm/v1/kv_cache_interface.py @heheda12345
|
||||||
/vllm/v1/kv_offload @ApostaC @orozery
|
/vllm/v1/kv_offload @ApostaC @orozery
|
||||||
/vllm/v1/worker/gpu/kv_connector.py @orozery
|
/vllm/v1/engine @njhill
|
||||||
|
/vllm/v1/executor @njhill
|
||||||
|
/vllm/v1/worker @njhill
|
||||||
/vllm/v1/worker/kv_connector_model_runner_mixin.py @orozery @NickLucche
|
/vllm/v1/worker/kv_connector_model_runner_mixin.py @orozery @NickLucche
|
||||||
|
|
||||||
# Model runner V2
|
# Model runner V2
|
||||||
/vllm/v1/worker/gpu @WoosukKwon
|
/vllm/v1/worker/gpu @WoosukKwon @njhill
|
||||||
|
/vllm/v1/worker/gpu/kv_connector.py @orozery
|
||||||
|
|
||||||
# Test ownership
|
# Test ownership
|
||||||
/.buildkite/lm-eval-harness @mgoin
|
/.buildkite/lm-eval-harness @mgoin
|
||||||
|
|||||||
3
.github/mergify.yml
vendored
3
.github/mergify.yml
vendored
@@ -259,8 +259,7 @@ pull_request_rules:
|
|||||||
- files=benchmarks/run_structured_output_benchmark.sh
|
- files=benchmarks/run_structured_output_benchmark.sh
|
||||||
- files=docs/features/structured_outputs.md
|
- files=docs/features/structured_outputs.md
|
||||||
- files=examples/offline_inference/structured_outputs.py
|
- files=examples/offline_inference/structured_outputs.py
|
||||||
- files=examples/online_serving/openai_chat_completion_structured_outputs.py
|
- files=examples/online_serving/structured_outputs/structured_outputs.py
|
||||||
- files=examples/online_serving/openai_chat_completion_structured_outputs_with_reasoning.py
|
|
||||||
- files~=^tests/v1/structured_output/
|
- files~=^tests/v1/structured_output/
|
||||||
- files=tests/v1/entrypoints/llm/test_struct_output_generate.py
|
- files=tests/v1/entrypoints/llm/test_struct_output_generate.py
|
||||||
- files~=^vllm/v1/structured_output/
|
- files~=^vllm/v1/structured_output/
|
||||||
|
|||||||
29
.github/workflows/bc-lint.yml
vendored
29
.github/workflows/bc-lint.yml
vendored
@@ -1,29 +0,0 @@
|
|||||||
name: BC Lint
|
|
||||||
|
|
||||||
on:
|
|
||||||
pull_request:
|
|
||||||
types:
|
|
||||||
- opened
|
|
||||||
- synchronize
|
|
||||||
- reopened
|
|
||||||
- labeled
|
|
||||||
- unlabeled
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
bc_lint:
|
|
||||||
if: github.repository_owner == 'vllm-project'
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- name: Run BC Lint Action
|
|
||||||
uses: pytorch/test-infra/.github/actions/bc-lint@main
|
|
||||||
with:
|
|
||||||
repo: ${{ github.event.pull_request.head.repo.full_name }}
|
|
||||||
base_sha: ${{ github.event.pull_request.base.sha }}
|
|
||||||
head_sha: ${{ github.event.pull_request.head.sha }}
|
|
||||||
suppression: ${{ contains(github.event.pull_request.labels.*.name, 'suppress-bc-linter') }}
|
|
||||||
docs_link: 'https://github.com/pytorch/test-infra/wiki/BC-Linter'
|
|
||||||
config_dir: .github
|
|
||||||
|
|
||||||
concurrency:
|
|
||||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}
|
|
||||||
cancel-in-progress: true
|
|
||||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -3,6 +3,8 @@
|
|||||||
|
|
||||||
# vllm-flash-attn built from source
|
# vllm-flash-attn built from source
|
||||||
vllm/vllm_flash_attn/*
|
vllm/vllm_flash_attn/*
|
||||||
|
!vllm/vllm_flash_attn/__init__.py
|
||||||
|
!vllm/vllm_flash_attn/flash_attn_interface.py
|
||||||
|
|
||||||
# OpenAI triton kernels copied from source
|
# OpenAI triton kernels copied from source
|
||||||
vllm/third_party/triton_kernels/*
|
vllm/third_party/triton_kernels/*
|
||||||
|
|||||||
@@ -725,7 +725,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
|||||||
# CUTLASS MoE kernels
|
# CUTLASS MoE kernels
|
||||||
|
|
||||||
# The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and ONLY works
|
# The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and ONLY works
|
||||||
# on Hopper). get_cutlass_(pplx_)moe_mm_data should only be compiled
|
# on Hopper). get_cutlass_(batched_)moe_mm_data should only be compiled
|
||||||
# if it's possible to compile MoE kernels that use its output.
|
# if it's possible to compile MoE kernels that use its output.
|
||||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}")
|
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}")
|
||||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
|
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
|
||||||
@@ -771,6 +771,33 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
|||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
# Expert-specialization MXFP8 blockscaled grouped kernels (SM100+).
|
||||||
|
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||||
|
cuda_archs_loose_intersection(ES_MXFP8_GROUPED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
|
||||||
|
else()
|
||||||
|
cuda_archs_loose_intersection(ES_MXFP8_GROUPED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
|
||||||
|
endif()
|
||||||
|
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND ES_MXFP8_GROUPED_MM_ARCHS)
|
||||||
|
set(SRCS
|
||||||
|
"csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm.cu"
|
||||||
|
"csrc/moe/mxfp8_moe/mxfp8_experts_quant.cu")
|
||||||
|
set_gencode_flags_for_srcs(
|
||||||
|
SRCS "${SRCS}"
|
||||||
|
CUDA_ARCHS "${ES_MXFP8_GROUPED_MM_ARCHS}")
|
||||||
|
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||||
|
list(APPEND VLLM_GPU_FLAGS "-DENABLE_ES_MXFP8_GROUPED_MM_SM100=1")
|
||||||
|
message(STATUS "Building ES MXFP8 grouped kernels for archs: ${ES_MXFP8_GROUPED_MM_ARCHS}")
|
||||||
|
else()
|
||||||
|
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8
|
||||||
|
AND ES_MXFP8_GROUPED_MM_ARCHS)
|
||||||
|
message(STATUS "Not building ES MXFP8 grouped kernels as CUDA Compiler version is "
|
||||||
|
"not >= 12.8.")
|
||||||
|
else()
|
||||||
|
message(STATUS "Not building ES MXFP8 grouped kernels as no compatible archs found "
|
||||||
|
"in CUDA target architectures.")
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
|
||||||
# DeepSeek V3 fused A GEMM kernel (requires SM 9.0+, Hopper and later)
|
# DeepSeek V3 fused A GEMM kernel (requires SM 9.0+, Hopper and later)
|
||||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||||
cuda_archs_loose_intersection(DSV3_FUSED_A_GEMM_ARCHS "9.0a;10.0f;11.0f" "${CUDA_ARCHS}")
|
cuda_archs_loose_intersection(DSV3_FUSED_A_GEMM_ARCHS "9.0a;10.0f;11.0f" "${CUDA_ARCHS}")
|
||||||
@@ -971,7 +998,8 @@ set(VLLM_MOE_EXT_SRC
|
|||||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||||
list(APPEND VLLM_MOE_EXT_SRC
|
list(APPEND VLLM_MOE_EXT_SRC
|
||||||
"csrc/moe/moe_wna16.cu"
|
"csrc/moe/moe_wna16.cu"
|
||||||
"csrc/moe/grouped_topk_kernels.cu")
|
"csrc/moe/grouped_topk_kernels.cu"
|
||||||
|
"csrc/moe/router_gemm.cu")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||||
|
|||||||
@@ -15,7 +15,6 @@ from .common import (
|
|||||||
BenchmarkConfig,
|
BenchmarkConfig,
|
||||||
BenchmarkResult,
|
BenchmarkResult,
|
||||||
MockLayer,
|
MockLayer,
|
||||||
MockModelConfig,
|
|
||||||
ResultsFormatter,
|
ResultsFormatter,
|
||||||
get_attention_scale,
|
get_attention_scale,
|
||||||
is_mla_backend,
|
is_mla_backend,
|
||||||
@@ -36,7 +35,6 @@ __all__ = [
|
|||||||
"ResultsFormatter",
|
"ResultsFormatter",
|
||||||
# Mock objects
|
# Mock objects
|
||||||
"MockLayer",
|
"MockLayer",
|
||||||
"MockModelConfig",
|
|
||||||
# Utilities
|
# Utilities
|
||||||
"setup_mla_dims",
|
"setup_mla_dims",
|
||||||
"get_attention_scale",
|
"get_attention_scale",
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ from dataclasses import asdict, dataclass
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
from batch_spec import get_batch_type, parse_batch_spec
|
from batch_spec import get_batch_type, parse_batch_spec
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
@@ -62,10 +61,7 @@ class MockHfConfig:
|
|||||||
# Import AttentionLayerBase at module level to avoid circular dependencies
|
# Import AttentionLayerBase at module level to avoid circular dependencies
|
||||||
try:
|
try:
|
||||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||||
|
|
||||||
_HAS_ATTENTION_LAYER_BASE = True
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
_HAS_ATTENTION_LAYER_BASE = False
|
|
||||||
AttentionLayerBase = object # Fallback
|
AttentionLayerBase = object # Fallback
|
||||||
|
|
||||||
|
|
||||||
@@ -167,95 +163,6 @@ class MockLayer(AttentionLayerBase):
|
|||||||
return self._kv_cache_spec
|
return self._kv_cache_spec
|
||||||
|
|
||||||
|
|
||||||
class MockModelConfig:
|
|
||||||
"""Mock model configuration."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
num_q_heads: int,
|
|
||||||
num_kv_heads: int,
|
|
||||||
head_dim: int,
|
|
||||||
dtype: torch.dtype = torch.float16,
|
|
||||||
max_model_len: int = 32768,
|
|
||||||
):
|
|
||||||
self._n_q = num_q_heads
|
|
||||||
self._n_kv = num_kv_heads
|
|
||||||
self._d = head_dim
|
|
||||||
self.dtype = dtype
|
|
||||||
self.max_model_len = max_model_len
|
|
||||||
|
|
||||||
def get_num_attention_heads(self, _=None) -> int:
|
|
||||||
return self._n_q
|
|
||||||
|
|
||||||
def get_num_kv_heads(self, _=None) -> int:
|
|
||||||
return self._n_kv
|
|
||||||
|
|
||||||
def get_head_size(self) -> int:
|
|
||||||
return self._d
|
|
||||||
|
|
||||||
def get_num_layers(self) -> int:
|
|
||||||
"""Mock method for layer count queries."""
|
|
||||||
return 1
|
|
||||||
|
|
||||||
def get_sliding_window_for_layer(self, _layer_idx: int):
|
|
||||||
"""Mock method for sliding window queries."""
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_logits_soft_cap_for_layer(self, _layer_idx: int):
|
|
||||||
"""Mock method for logits soft cap queries."""
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_sm_scale_for_layer(self, _layer_idx: int) -> float:
|
|
||||||
"""Mock method for SM scale queries."""
|
|
||||||
return 1.0 / (self.get_head_size() ** 0.5)
|
|
||||||
|
|
||||||
|
|
||||||
class MockParallelConfig:
|
|
||||||
"""Mock parallel configuration."""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class MockCompilationConfig:
|
|
||||||
"""Mock compilation configuration."""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.full_cuda_graph = False
|
|
||||||
self.static_forward_context = {}
|
|
||||||
|
|
||||||
|
|
||||||
class MockVLLMConfig:
|
|
||||||
"""Mock VLLM configuration."""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.compilation_config = MockCompilationConfig()
|
|
||||||
|
|
||||||
|
|
||||||
class MockRunner:
|
|
||||||
"""Mock GPU runner for metadata builders."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
seq_lens: np.ndarray,
|
|
||||||
query_start_locs: np.ndarray,
|
|
||||||
device: torch.device,
|
|
||||||
num_q_heads: int,
|
|
||||||
num_kv_heads: int,
|
|
||||||
head_dim: int,
|
|
||||||
dtype: torch.dtype,
|
|
||||||
):
|
|
||||||
self.model_config = MockModelConfig(num_q_heads, num_kv_heads, head_dim, dtype)
|
|
||||||
self.parallel_config = MockParallelConfig()
|
|
||||||
self.vllm_config = MockVLLMConfig()
|
|
||||||
self.seq_lens_np = seq_lens
|
|
||||||
self.query_start_loc_np = query_start_locs
|
|
||||||
self.device = device
|
|
||||||
self.attention_chunk_size = None
|
|
||||||
self.num_query_heads = num_q_heads
|
|
||||||
self.num_kv_heads = num_kv_heads
|
|
||||||
self.dtype = dtype
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ParameterSweep:
|
class ParameterSweep:
|
||||||
"""Configuration for sweeping a backend parameter."""
|
"""Configuration for sweeping a backend parameter."""
|
||||||
|
|||||||
@@ -85,7 +85,6 @@ start_server() {
|
|||||||
# Each argument and its value are separate elements.
|
# Each argument and its value are separate elements.
|
||||||
local common_args_array=(
|
local common_args_array=(
|
||||||
"$MODEL"
|
"$MODEL"
|
||||||
"--disable-log-requests"
|
|
||||||
"--port" "8004"
|
"--port" "8004"
|
||||||
"--host" "$HOSTNAME"
|
"--host" "$HOSTNAME"
|
||||||
"--gpu-memory-utilization" "$gpu_memory_utilization"
|
"--gpu-memory-utilization" "$gpu_memory_utilization"
|
||||||
|
|||||||
@@ -649,9 +649,3 @@ ASYNC_REQUEST_FUNCS = {
|
|||||||
"sglang": async_request_openai_completions,
|
"sglang": async_request_openai_completions,
|
||||||
"llama.cpp": async_request_openai_completions,
|
"llama.cpp": async_request_openai_completions,
|
||||||
}
|
}
|
||||||
|
|
||||||
OPENAI_COMPATIBLE_BACKENDS = [
|
|
||||||
k
|
|
||||||
for k, v in ASYNC_REQUEST_FUNCS.items()
|
|
||||||
if v in (async_request_openai_completions, async_request_openai_chat_completions)
|
|
||||||
]
|
|
||||||
|
|||||||
@@ -1,78 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import argparse
|
|
||||||
import json
|
|
||||||
import math
|
|
||||||
import os
|
|
||||||
import time
|
import time
|
||||||
from types import TracebackType
|
from types import TracebackType
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
|
|
||||||
def convert_to_pytorch_benchmark_format(
|
|
||||||
args: argparse.Namespace, metrics: dict[str, list], extra_info: dict[str, Any]
|
|
||||||
) -> list:
|
|
||||||
"""
|
|
||||||
Save the benchmark results in the format used by PyTorch OSS benchmark with
|
|
||||||
on metric per record
|
|
||||||
https://github.com/pytorch/pytorch/wiki/How-to-integrate-with-PyTorch-OSS-benchmark-database
|
|
||||||
"""
|
|
||||||
records = []
|
|
||||||
if not os.environ.get("SAVE_TO_PYTORCH_BENCHMARK_FORMAT", False):
|
|
||||||
return records
|
|
||||||
|
|
||||||
for name, benchmark_values in metrics.items():
|
|
||||||
record = {
|
|
||||||
"benchmark": {
|
|
||||||
"name": "vLLM benchmark",
|
|
||||||
"extra_info": {
|
|
||||||
"args": vars(args),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"model": {
|
|
||||||
"name": args.model,
|
|
||||||
},
|
|
||||||
"metric": {
|
|
||||||
"name": name,
|
|
||||||
"benchmark_values": benchmark_values,
|
|
||||||
"extra_info": extra_info,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
tp = record["benchmark"]["extra_info"]["args"].get("tensor_parallel_size")
|
|
||||||
# Save tensor_parallel_size parameter if it's part of the metadata
|
|
||||||
if not tp and "tensor_parallel_size" in extra_info:
|
|
||||||
record["benchmark"]["extra_info"]["args"]["tensor_parallel_size"] = (
|
|
||||||
extra_info["tensor_parallel_size"]
|
|
||||||
)
|
|
||||||
|
|
||||||
records.append(record)
|
|
||||||
|
|
||||||
return records
|
|
||||||
|
|
||||||
|
|
||||||
class InfEncoder(json.JSONEncoder):
|
|
||||||
def clear_inf(self, o: Any):
|
|
||||||
if isinstance(o, dict):
|
|
||||||
return {k: self.clear_inf(v) for k, v in o.items()}
|
|
||||||
elif isinstance(o, list):
|
|
||||||
return [self.clear_inf(v) for v in o]
|
|
||||||
elif isinstance(o, float) and math.isinf(o):
|
|
||||||
return "inf"
|
|
||||||
return o
|
|
||||||
|
|
||||||
def iterencode(self, o: Any, *args, **kwargs) -> Any:
|
|
||||||
return super().iterencode(self.clear_inf(o), *args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def write_to_json(filename: str, records: list) -> None:
|
|
||||||
with open(filename, "w") as f:
|
|
||||||
json.dump(
|
|
||||||
records,
|
|
||||||
f,
|
|
||||||
cls=InfEncoder,
|
|
||||||
default=lambda o: f"<{type(o).__name__} object is not JSON serializable>",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Collect time and generate time metrics
|
# Collect time and generate time metrics
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
# Cutlass bench utils
|
# Cutlass bench utils
|
||||||
from collections.abc import Iterable
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -86,15 +85,3 @@ def make_rand_sparse_tensors(
|
|||||||
|
|
||||||
# Compressed B, Metadata, Original A, B
|
# Compressed B, Metadata, Original A, B
|
||||||
return b_compressed, e, a, b
|
return b_compressed, e, a, b
|
||||||
|
|
||||||
|
|
||||||
def make_n_rand_sparse_tensors(
|
|
||||||
num_tensors: int, dtype: torch.dtype, m: int, n: int, k: int
|
|
||||||
) -> tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]:
|
|
||||||
ABs = []
|
|
||||||
for _ in range(num_tensors):
|
|
||||||
b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k)
|
|
||||||
if b_comp is not None:
|
|
||||||
ABs.append(make_rand_sparse_tensors(dtype, m, n, k))
|
|
||||||
BComps, Es, As, Bs = zip(*ABs)
|
|
||||||
return list(BComps), list(Es), list(As), list(Bs)
|
|
||||||
|
|||||||
@@ -1,45 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import time
|
|
||||||
|
|
||||||
|
|
||||||
class RateLimiter:
|
|
||||||
"""Token bucket rate limiter implementation"""
|
|
||||||
|
|
||||||
def __init__(self, rate_limit):
|
|
||||||
self.rate_limit = rate_limit # Requests per second
|
|
||||||
self.num_available_tokens = rate_limit # Available tokens
|
|
||||||
self.last_refill = time.monotonic() # Last token refill time
|
|
||||||
self.lock = asyncio.Lock() # Synchronization lock
|
|
||||||
|
|
||||||
async def acquire(self):
|
|
||||||
"""Acquire a token from the rate limiter"""
|
|
||||||
while True:
|
|
||||||
async with self.lock:
|
|
||||||
current_time = time.monotonic()
|
|
||||||
elapsed = current_time - self.last_refill
|
|
||||||
|
|
||||||
# Refill num_available_tokens if more than 1 second has passed
|
|
||||||
if elapsed > 1.0:
|
|
||||||
self.num_available_tokens = self.rate_limit
|
|
||||||
self.last_refill = current_time
|
|
||||||
|
|
||||||
# Check if num_available_tokens are available
|
|
||||||
if self.num_available_tokens > 0:
|
|
||||||
self.num_available_tokens -= 1
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Calculate wait time if no num_available_tokens available
|
|
||||||
wait_time = 1.0 - elapsed
|
|
||||||
await asyncio.sleep(wait_time)
|
|
||||||
|
|
||||||
async def __aenter__(self):
|
|
||||||
"""Enter async context manager - acquire token"""
|
|
||||||
await self.acquire()
|
|
||||||
return self
|
|
||||||
|
|
||||||
async def __aexit__(self, exc_type, exc_value, traceback):
|
|
||||||
"""Exit async context manager - no cleanup needed"""
|
|
||||||
pass
|
|
||||||
@@ -1,39 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
from collections import deque
|
|
||||||
|
|
||||||
|
|
||||||
class RequestQueue:
|
|
||||||
"""Request queue manager with concurrency control"""
|
|
||||||
|
|
||||||
def __init__(self, max_concurrent, max_queue_size):
|
|
||||||
# Maximum concurrent requests
|
|
||||||
self.max_concurrent = max_concurrent
|
|
||||||
self.max_queue_size = max_queue_size # Maximum queue size
|
|
||||||
# Concurrency control
|
|
||||||
self.semaphore = asyncio.Semaphore(max_concurrent)
|
|
||||||
self.queue = deque() # Request queue
|
|
||||||
self.queue_size = 0 # Current queue size
|
|
||||||
self.lock = asyncio.Lock() # Sync queue Lock
|
|
||||||
|
|
||||||
async def enqueue(self, task):
|
|
||||||
"""Add a request task to the queue"""
|
|
||||||
async with self.lock:
|
|
||||||
if self.queue_size >= self.max_queue_size:
|
|
||||||
return False
|
|
||||||
|
|
||||||
self.queue.append(task)
|
|
||||||
self.queue_size += 1
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def process(self):
|
|
||||||
"""Process queued requests using semaphore for concurrency control"""
|
|
||||||
while True:
|
|
||||||
if self.queue:
|
|
||||||
async with self.semaphore, self.lock:
|
|
||||||
task = self.queue.popleft()
|
|
||||||
self.queue_size -= 1
|
|
||||||
await task
|
|
||||||
await asyncio.sleep(0.01) # Yield control to event loop
|
|
||||||
@@ -12,12 +12,12 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
|||||||
from tests.kernels.moe.utils import make_dummy_moe_config
|
from tests.kernels.moe.utils import make_dummy_moe_config
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||||
|
from vllm.model_executor.layers.fused_moe.all2all_utils import (
|
||||||
|
maybe_make_prepare_finalize,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
|
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
|
||||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
|
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
|
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
|
||||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
|
||||||
MoEPrepareAndFinalizeNoEP,
|
|
||||||
)
|
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||||
from vllm.v1.worker.workspace import init_workspace_manager
|
from vllm.v1.worker.workspace import init_workspace_manager
|
||||||
@@ -137,15 +137,21 @@ def bench_run(
|
|||||||
per_out_ch_quant=per_out_ch,
|
per_out_ch_quant=per_out_ch,
|
||||||
)
|
)
|
||||||
|
|
||||||
fn = mk.FusedMoEModularKernel(
|
moe_config = make_dummy_moe_config(
|
||||||
MoEPrepareAndFinalizeNoEP(),
|
num_experts=num_experts,
|
||||||
|
hidden_dim=k,
|
||||||
|
intermediate_size_per_partition=n,
|
||||||
|
in_dtype=a.dtype,
|
||||||
|
)
|
||||||
|
fn = mk.FusedMoEKernel(
|
||||||
|
maybe_make_prepare_finalize(
|
||||||
|
moe=moe_config,
|
||||||
|
quant_config=quant_config,
|
||||||
|
allow_new_interface=True,
|
||||||
|
use_monolithic=False,
|
||||||
|
),
|
||||||
CutlassExpertsFp8(
|
CutlassExpertsFp8(
|
||||||
moe_config=make_dummy_moe_config(
|
moe_config=moe_config,
|
||||||
num_experts=num_experts,
|
|
||||||
hidden_dim=k,
|
|
||||||
intermediate_size_per_partition=n,
|
|
||||||
in_dtype=a.dtype,
|
|
||||||
),
|
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -15,6 +15,9 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
|||||||
from tests.kernels.moe.utils import make_dummy_moe_config
|
from tests.kernels.moe.utils import make_dummy_moe_config
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||||
|
from vllm.model_executor.layers.fused_moe.all2all_utils import (
|
||||||
|
maybe_make_prepare_finalize,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.fused_moe.config import (
|
from vllm.model_executor.layers.fused_moe.config import (
|
||||||
fp8_w8a8_moe_quant_config,
|
fp8_w8a8_moe_quant_config,
|
||||||
nvfp4_moe_quant_config,
|
nvfp4_moe_quant_config,
|
||||||
@@ -23,9 +26,6 @@ from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
|||||||
CutlassExpertsFp4,
|
CutlassExpertsFp4,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
|
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
|
||||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
|
||||||
MoEPrepareAndFinalizeNoEP,
|
|
||||||
)
|
|
||||||
from vllm.scalar_type import scalar_types
|
from vllm.scalar_type import scalar_types
|
||||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||||
from vllm.v1.worker.workspace import init_workspace_manager
|
from vllm.v1.worker.workspace import init_workspace_manager
|
||||||
@@ -196,10 +196,21 @@ def bench_run(
|
|||||||
g2_alphas=w2_gs,
|
g2_alphas=w2_gs,
|
||||||
)
|
)
|
||||||
|
|
||||||
kernel = mk.FusedMoEModularKernel(
|
moe_config = make_dummy_moe_config(
|
||||||
MoEPrepareAndFinalizeNoEP(),
|
num_experts=num_experts,
|
||||||
|
hidden_dim=k,
|
||||||
|
intermediate_size_per_partition=n,
|
||||||
|
in_dtype=a.dtype,
|
||||||
|
)
|
||||||
|
kernel = mk.FusedMoEKernel(
|
||||||
|
maybe_make_prepare_finalize(
|
||||||
|
moe=moe_config,
|
||||||
|
quant_config=quant_config,
|
||||||
|
allow_new_interface=True,
|
||||||
|
use_monolithic=False,
|
||||||
|
),
|
||||||
CutlassExpertsFp4(
|
CutlassExpertsFp4(
|
||||||
make_dummy_moe_config(),
|
moe_config=moe_config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@@ -240,11 +251,17 @@ def bench_run(
|
|||||||
g1_alphas=w1_gs,
|
g1_alphas=w1_gs,
|
||||||
g2_alphas=w2_gs,
|
g2_alphas=w2_gs,
|
||||||
)
|
)
|
||||||
|
moe_config = make_dummy_moe_config()
|
||||||
|
|
||||||
kernel = mk.FusedMoEModularKernel(
|
kernel = mk.FusedMoEKernel(
|
||||||
MoEPrepareAndFinalizeNoEP(),
|
maybe_make_prepare_finalize(
|
||||||
|
moe=moe_config,
|
||||||
|
quant_config=quant_config,
|
||||||
|
allow_new_interface=True,
|
||||||
|
use_monolithic=False,
|
||||||
|
),
|
||||||
CutlassExpertsFp4(
|
CutlassExpertsFp4(
|
||||||
make_dummy_moe_config(),
|
moe_config=moe_config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -9,15 +9,15 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
|||||||
from tests.kernels.moe.utils import make_dummy_moe_config
|
from tests.kernels.moe.utils import make_dummy_moe_config
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||||
|
from vllm.model_executor.layers.fused_moe.all2all_utils import (
|
||||||
|
maybe_make_prepare_finalize,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
|
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
|
||||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
|
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||||
fused_experts,
|
fused_experts,
|
||||||
fused_topk,
|
fused_topk,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
|
||||||
MoEPrepareAndFinalizeNoEP,
|
|
||||||
)
|
|
||||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||||
from vllm.v1.worker.workspace import init_workspace_manager
|
from vllm.v1.worker.workspace import init_workspace_manager
|
||||||
|
|
||||||
@@ -131,16 +131,22 @@ def bench_run(
|
|||||||
w2_scale=w2_scale,
|
w2_scale=w2_scale,
|
||||||
per_act_token_quant=per_act_token,
|
per_act_token_quant=per_act_token,
|
||||||
)
|
)
|
||||||
|
moe_config = make_dummy_moe_config(
|
||||||
|
num_experts=w2.shape[0],
|
||||||
|
hidden_dim=w2.shape[1],
|
||||||
|
intermediate_size_per_partition=w2.shape[2],
|
||||||
|
in_dtype=a.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
fn = mk.FusedMoEModularKernel(
|
fn = mk.FusedMoEKernel(
|
||||||
MoEPrepareAndFinalizeNoEP(),
|
maybe_make_prepare_finalize(
|
||||||
|
moe=moe_config,
|
||||||
|
quant_config=quant_config,
|
||||||
|
allow_new_interface=True,
|
||||||
|
use_monolithic=False,
|
||||||
|
),
|
||||||
CutlassExpertsFp8(
|
CutlassExpertsFp8(
|
||||||
moe_config=make_dummy_moe_config(
|
moe_config=moe_config,
|
||||||
num_experts=w2.shape[0],
|
|
||||||
hidden_dim=w2.shape[1],
|
|
||||||
intermediate_size_per_partition=w2.shape[2],
|
|
||||||
in_dtype=a.dtype,
|
|
||||||
),
|
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@@ -163,16 +169,22 @@ def bench_run(
|
|||||||
w2_scale=w2_scale,
|
w2_scale=w2_scale,
|
||||||
per_act_token_quant=per_act_token,
|
per_act_token_quant=per_act_token,
|
||||||
)
|
)
|
||||||
|
moe_config = make_dummy_moe_config(
|
||||||
|
num_experts=w2.shape[0],
|
||||||
|
hidden_dim=w2.shape[1],
|
||||||
|
intermediate_size_per_partition=w2.shape[2],
|
||||||
|
in_dtype=a.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
fn = mk.FusedMoEModularKernel(
|
fn = mk.FusedMoEKernel(
|
||||||
MoEPrepareAndFinalizeNoEP(),
|
maybe_make_prepare_finalize(
|
||||||
|
moe=moe_config,
|
||||||
|
quant_config=quant_config,
|
||||||
|
allow_new_interface=True,
|
||||||
|
use_monolithic=False,
|
||||||
|
),
|
||||||
CutlassExpertsFp8(
|
CutlassExpertsFp8(
|
||||||
moe_config=make_dummy_moe_config(
|
moe_config=moe_config,
|
||||||
num_experts=w2.shape[0],
|
|
||||||
hidden_dim=w2.shape[1],
|
|
||||||
intermediate_size_per_partition=w2.shape[2],
|
|
||||||
in_dtype=a.dtype,
|
|
||||||
),
|
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -17,6 +17,9 @@ from ray.experimental.tqdm_ray import tqdm
|
|||||||
|
|
||||||
from vllm.model_executor.layers.fused_moe import fused_topk
|
from vllm.model_executor.layers.fused_moe import fused_topk
|
||||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||||
|
from vllm.model_executor.layers.fused_moe.all2all_utils import (
|
||||||
|
maybe_make_prepare_finalize,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.fused_moe.config import (
|
from vllm.model_executor.layers.fused_moe.config import (
|
||||||
FusedMoEConfig,
|
FusedMoEConfig,
|
||||||
FusedMoEParallelConfig,
|
FusedMoEParallelConfig,
|
||||||
@@ -242,24 +245,33 @@ def benchmark_config(
|
|||||||
|
|
||||||
deep_gemm_experts = None
|
deep_gemm_experts = None
|
||||||
if use_deep_gemm:
|
if use_deep_gemm:
|
||||||
deep_gemm_experts = mk.FusedMoEModularKernel(
|
moe_config = (
|
||||||
prepare_finalize=MoEPrepareAndFinalizeNoEP(),
|
FusedMoEConfig(
|
||||||
|
num_experts=num_experts,
|
||||||
|
experts_per_token=topk,
|
||||||
|
hidden_dim=hidden_size,
|
||||||
|
intermediate_size_per_partition=shard_intermediate_size,
|
||||||
|
num_local_experts=num_experts,
|
||||||
|
num_logical_experts=num_experts,
|
||||||
|
activation=MoEActivation.SILU,
|
||||||
|
moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
|
||||||
|
in_dtype=init_dtype,
|
||||||
|
routing_method=RoutingMethodType.TopK,
|
||||||
|
device="cuda",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
deep_gemm_experts = mk.FusedMoEKernel(
|
||||||
|
prepare_finalize=maybe_make_prepare_finalize(
|
||||||
|
moe=moe_config,
|
||||||
|
quant_config=quant_config,
|
||||||
|
allow_new_interface=True,
|
||||||
|
use_monolithic=False,
|
||||||
|
),
|
||||||
fused_experts=TritonOrDeepGemmExperts(
|
fused_experts=TritonOrDeepGemmExperts(
|
||||||
moe_config=FusedMoEConfig(
|
moe_config=moe_config,
|
||||||
num_experts=num_experts,
|
|
||||||
experts_per_token=topk,
|
|
||||||
hidden_dim=hidden_size,
|
|
||||||
intermediate_size_per_partition=shard_intermediate_size,
|
|
||||||
num_local_experts=num_experts,
|
|
||||||
num_logical_experts=num_experts,
|
|
||||||
activation=MoEActivation.SILU,
|
|
||||||
moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
|
|
||||||
in_dtype=init_dtype,
|
|
||||||
routing_method=RoutingMethodType.TopK,
|
|
||||||
device="cuda",
|
|
||||||
),
|
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
),
|
),
|
||||||
|
inplace=not disable_inplace(),
|
||||||
)
|
)
|
||||||
|
|
||||||
with override_config(config):
|
with override_config(config):
|
||||||
@@ -269,8 +281,16 @@ def benchmark_config(
|
|||||||
|
|
||||||
inplace = not disable_inplace()
|
inplace = not disable_inplace()
|
||||||
if use_deep_gemm:
|
if use_deep_gemm:
|
||||||
return deep_gemm_experts(
|
return deep_gemm_experts.apply(
|
||||||
x, w1, w2, topk_weights, topk_ids, inplace=inplace
|
x,
|
||||||
|
w1,
|
||||||
|
w2,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
activation=MoEActivation.SILU,
|
||||||
|
global_num_experts=num_experts,
|
||||||
|
apply_router_weight_on_input=False,
|
||||||
|
expert_map=False,
|
||||||
)
|
)
|
||||||
return fused_experts(
|
return fused_experts(
|
||||||
x,
|
x,
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ First start serving your model
|
|||||||
```bash
|
```bash
|
||||||
export MODEL_PATH=/models/meta-llama/Meta-Llama-3.1-8B-Instruct/
|
export MODEL_PATH=/models/meta-llama/Meta-Llama-3.1-8B-Instruct/
|
||||||
|
|
||||||
vllm serve $MODEL_PATH --served-model-name Llama --disable-log-requests
|
vllm serve $MODEL_PATH --served-model-name Llama
|
||||||
```
|
```
|
||||||
|
|
||||||
The variable `MODEL_PATH` should be a path to the model files (e.g. downloaded from huggingface).
|
The variable `MODEL_PATH` should be a path to the model files (e.g. downloaded from huggingface).
|
||||||
|
|||||||
@@ -13,28 +13,16 @@ endif()
|
|||||||
#
|
#
|
||||||
# Define environment variables for special configurations
|
# Define environment variables for special configurations
|
||||||
#
|
#
|
||||||
set(ENABLE_AVX2 $ENV{VLLM_CPU_AVX2})
|
set(ENABLE_X86_ISA $ENV{VLLM_CPU_X86})
|
||||||
set(ENABLE_AVX512 $ENV{VLLM_CPU_AVX512})
|
|
||||||
set(ENABLE_AVX512BF16 $ENV{VLLM_CPU_AVX512BF16})
|
|
||||||
set(ENABLE_AVX512VNNI $ENV{VLLM_CPU_AVX512VNNI})
|
|
||||||
set(ENABLE_AMXBF16 $ENV{VLLM_CPU_AMXBF16})
|
|
||||||
set(ENABLE_ARM_BF16 $ENV{VLLM_CPU_ARM_BF16})
|
set(ENABLE_ARM_BF16 $ENV{VLLM_CPU_ARM_BF16})
|
||||||
|
|
||||||
include_directories("${CMAKE_SOURCE_DIR}/csrc")
|
include_directories("${CMAKE_SOURCE_DIR}/csrc")
|
||||||
|
|
||||||
|
|
||||||
set (ENABLE_NUMA TRUE)
|
set (ENABLE_NUMA TRUE)
|
||||||
|
|
||||||
#
|
#
|
||||||
# Check the compile flags
|
# Check the compile flags
|
||||||
#
|
#
|
||||||
|
|
||||||
if (CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64")
|
|
||||||
list(APPEND CXX_COMPILE_FLAGS
|
|
||||||
"-mf16c"
|
|
||||||
)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if(MACOSX_FOUND)
|
if(MACOSX_FOUND)
|
||||||
list(APPEND CXX_COMPILE_FLAGS
|
list(APPEND CXX_COMPILE_FLAGS
|
||||||
"-DVLLM_CPU_EXTENSION")
|
"-DVLLM_CPU_EXTENSION")
|
||||||
@@ -78,18 +66,6 @@ function(check_sysctl TARGET OUT)
|
|||||||
endif()
|
endif()
|
||||||
endfunction()
|
endfunction()
|
||||||
|
|
||||||
|
|
||||||
function (is_avx512_disabled OUT)
|
|
||||||
set(DISABLE_AVX512 $ENV{VLLM_CPU_DISABLE_AVX512})
|
|
||||||
if(DISABLE_AVX512 AND DISABLE_AVX512 STREQUAL "true")
|
|
||||||
set(${OUT} ON PARENT_SCOPE)
|
|
||||||
else()
|
|
||||||
set(${OUT} OFF PARENT_SCOPE)
|
|
||||||
endif()
|
|
||||||
endfunction()
|
|
||||||
|
|
||||||
is_avx512_disabled(AVX512_DISABLED)
|
|
||||||
|
|
||||||
if (MACOSX_FOUND AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
|
if (MACOSX_FOUND AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
|
||||||
message(STATUS "Apple Silicon Detected")
|
message(STATUS "Apple Silicon Detected")
|
||||||
set(APPLE_SILICON_FOUND TRUE)
|
set(APPLE_SILICON_FOUND TRUE)
|
||||||
@@ -97,8 +73,6 @@ if (MACOSX_FOUND AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
|
|||||||
check_sysctl(hw.optional.neon ASIMD_FOUND)
|
check_sysctl(hw.optional.neon ASIMD_FOUND)
|
||||||
check_sysctl(hw.optional.arm.FEAT_BF16 ARM_BF16_FOUND)
|
check_sysctl(hw.optional.arm.FEAT_BF16 ARM_BF16_FOUND)
|
||||||
else()
|
else()
|
||||||
find_isa(${CPUINFO} "avx2" AVX2_FOUND)
|
|
||||||
find_isa(${CPUINFO} "avx512f" AVX512_FOUND)
|
|
||||||
find_isa(${CPUINFO} "Power11" POWER11_FOUND)
|
find_isa(${CPUINFO} "Power11" POWER11_FOUND)
|
||||||
find_isa(${CPUINFO} "POWER10" POWER10_FOUND)
|
find_isa(${CPUINFO} "POWER10" POWER10_FOUND)
|
||||||
find_isa(${CPUINFO} "POWER9" POWER9_FOUND)
|
find_isa(${CPUINFO} "POWER9" POWER9_FOUND)
|
||||||
@@ -108,77 +82,32 @@ else()
|
|||||||
find_isa(${CPUINFO} "v" RVV_FOUND) # Check for RISC-V RVV support
|
find_isa(${CPUINFO} "v" RVV_FOUND) # Check for RISC-V RVV support
|
||||||
|
|
||||||
# Support cross-compilation by allowing override via environment variables
|
# Support cross-compilation by allowing override via environment variables
|
||||||
if (ENABLE_AVX2)
|
|
||||||
set(AVX2_FOUND ON)
|
|
||||||
message(STATUS "AVX2 support enabled via VLLM_CPU_AVX2 environment variable")
|
|
||||||
endif()
|
|
||||||
if (ENABLE_AVX512)
|
|
||||||
set(AVX512_FOUND ON)
|
|
||||||
message(STATUS "AVX512 support enabled via VLLM_CPU_AVX512 environment variable")
|
|
||||||
endif()
|
|
||||||
if (ENABLE_ARM_BF16)
|
if (ENABLE_ARM_BF16)
|
||||||
set(ARM_BF16_FOUND ON)
|
set(ARM_BF16_FOUND ON)
|
||||||
message(STATUS "ARM BF16 support enabled via VLLM_CPU_ARM_BF16 environment variable")
|
message(STATUS "ARM BF16 support enabled via VLLM_CPU_ARM_BF16 environment variable")
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (AVX512_FOUND AND NOT AVX512_DISABLED)
|
if (CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64|amd64" OR ENABLE_X86_ISA)
|
||||||
list(APPEND CXX_COMPILE_FLAGS
|
set(ENABLE_X86_ISA ON)
|
||||||
|
if (NOT (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
|
||||||
|
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3))
|
||||||
|
message(FATAL_ERROR "X86 backend requires gcc/g++ >= 12.3")
|
||||||
|
endif()
|
||||||
|
list(APPEND CXX_COMPILE_FLAGS "-mf16c")
|
||||||
|
list(APPEND CXX_COMPILE_FLAGS_AVX512 ${CXX_COMPILE_FLAGS})
|
||||||
|
list(APPEND CXX_COMPILE_FLAGS_AVX2 ${CXX_COMPILE_FLAGS})
|
||||||
|
list(APPEND CXX_COMPILE_FLAGS_AVX512
|
||||||
"-mavx512f"
|
"-mavx512f"
|
||||||
"-mavx512vl"
|
"-mavx512vl"
|
||||||
"-mavx512bw"
|
"-mavx512bw"
|
||||||
"-mavx512dq")
|
"-mavx512dq"
|
||||||
|
"-mavx512bf16"
|
||||||
find_isa(${CPUINFO} "avx512_bf16" AVX512BF16_FOUND)
|
"-mavx512vnni"
|
||||||
if (AVX512BF16_FOUND OR ENABLE_AVX512BF16)
|
"-mamx-bf16"
|
||||||
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
|
"-mamx-tile")
|
||||||
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3)
|
list(APPEND CXX_COMPILE_FLAGS_AVX2
|
||||||
list(APPEND CXX_COMPILE_FLAGS "-mavx512bf16")
|
"-mavx2")
|
||||||
set(ENABLE_AVX512BF16 ON)
|
|
||||||
else()
|
|
||||||
set(ENABLE_AVX512BF16 OFF)
|
|
||||||
message(WARNING "Disable AVX512-BF16 ISA support, requires gcc/g++ >= 12.3")
|
|
||||||
endif()
|
|
||||||
else()
|
|
||||||
set(ENABLE_AVX512BF16 OFF)
|
|
||||||
message(WARNING "Disable AVX512-BF16 ISA support, no avx512_bf16 found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512BF16=1.")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
find_isa(${CPUINFO} "avx512_vnni" AVX512VNNI_FOUND)
|
|
||||||
if (AVX512VNNI_FOUND OR ENABLE_AVX512VNNI)
|
|
||||||
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
|
|
||||||
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3)
|
|
||||||
list(APPEND CXX_COMPILE_FLAGS "-mavx512vnni")
|
|
||||||
set(ENABLE_AVX512VNNI ON)
|
|
||||||
else()
|
|
||||||
set(ENABLE_AVX512VNNI OFF)
|
|
||||||
message(WARNING "Disable AVX512-VNNI ISA support, requires gcc/g++ >= 12.3")
|
|
||||||
endif()
|
|
||||||
else()
|
|
||||||
set(ENABLE_AVX512VNNI OFF)
|
|
||||||
message(WARNING "Disable AVX512-VNNI ISA support, no avx512_vnni found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512VNNI=1.")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
find_isa(${CPUINFO} "amx_bf16" AMXBF16_FOUND)
|
|
||||||
if (AMXBF16_FOUND OR ENABLE_AMXBF16)
|
|
||||||
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
|
|
||||||
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3)
|
|
||||||
list(APPEND CXX_COMPILE_FLAGS "-mamx-bf16" "-mamx-tile")
|
|
||||||
set(ENABLE_AMXBF16 ON)
|
|
||||||
add_compile_definitions(-DCPU_CAPABILITY_AMXBF16)
|
|
||||||
else()
|
|
||||||
set(ENABLE_AMXBF16 OFF)
|
|
||||||
message(WARNING "Disable AMX_BF16 ISA support, requires gcc/g++ >= 12.3")
|
|
||||||
endif()
|
|
||||||
else()
|
|
||||||
set(ENABLE_AMXBF16 OFF)
|
|
||||||
message(WARNING "Disable AMX_BF16 ISA support, no amx_bf16 found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AMXBF16=1.")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
elseif (AVX2_FOUND)
|
|
||||||
list(APPEND CXX_COMPILE_FLAGS "-mavx2")
|
|
||||||
message(WARNING "vLLM CPU backend using AVX2 ISA")
|
|
||||||
|
|
||||||
elseif (POWER9_FOUND OR POWER10_FOUND OR POWER11_FOUND)
|
elseif (POWER9_FOUND OR POWER10_FOUND OR POWER11_FOUND)
|
||||||
message(STATUS "PowerPC detected")
|
message(STATUS "PowerPC detected")
|
||||||
if (POWER9_FOUND)
|
if (POWER9_FOUND)
|
||||||
@@ -219,12 +148,12 @@ elseif (CMAKE_SYSTEM_PROCESSOR MATCHES "riscv64")
|
|||||||
list(APPEND CXX_COMPILE_FLAGS "-march=rv64gc")
|
list(APPEND CXX_COMPILE_FLAGS "-march=rv64gc")
|
||||||
endif()
|
endif()
|
||||||
else()
|
else()
|
||||||
message(FATAL_ERROR "vLLM CPU backend requires AVX512, AVX2, Power9+ ISA, S390X ISA, ARMv8 or RISC-V support.")
|
message(FATAL_ERROR "vLLM CPU backend requires X86, Power9+ ISA, S390X ISA, ARMv8 or RISC-V support.")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
|
||||||
# Build oneDNN for GEMM kernels (only for x86-AVX512 /ARM platforms)
|
# Build oneDNN for GEMM kernels
|
||||||
if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON_FOUND) OR POWER9_FOUND OR POWER10_FOUND OR POWER11_FOUND)
|
if (ENABLE_X86_ISA OR (ASIMD_FOUND AND NOT APPLE_SILICON_FOUND) OR POWER9_FOUND OR POWER10_FOUND OR POWER11_FOUND)
|
||||||
# Fetch and build Arm Compute Library (ACL) as oneDNN's backend for AArch64
|
# Fetch and build Arm Compute Library (ACL) as oneDNN's backend for AArch64
|
||||||
# TODO [fadara01]: remove this once ACL can be fetched and built automatically as a dependency of oneDNN
|
# TODO [fadara01]: remove this once ACL can be fetched and built automatically as a dependency of oneDNN
|
||||||
set(ONEDNN_AARCH64_USE_ACL OFF CACHE BOOL "")
|
set(ONEDNN_AARCH64_USE_ACL OFF CACHE BOOL "")
|
||||||
@@ -329,13 +258,21 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON
|
|||||||
set(ONEDNN_ENABLE_WORKLOAD "INFERENCE")
|
set(ONEDNN_ENABLE_WORKLOAD "INFERENCE")
|
||||||
set(ONEDNN_ENABLE_PRIMITIVE "MATMUL;REORDER")
|
set(ONEDNN_ENABLE_PRIMITIVE "MATMUL;REORDER")
|
||||||
set(ONEDNN_BUILD_GRAPH "OFF")
|
set(ONEDNN_BUILD_GRAPH "OFF")
|
||||||
set(ONEDNN_ENABLE_JIT_PROFILING "OFF")
|
set(ONEDNN_ENABLE_JIT_PROFILING "ON")
|
||||||
set(ONEDNN_ENABLE_ITT_TASKS "OFF")
|
set(ONEDNN_ENABLE_ITT_TASKS "OFF")
|
||||||
set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF")
|
set(ONEDNN_ENABLE_MAX_CPU_ISA "ON")
|
||||||
set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF")
|
set(ONEDNN_ENABLE_CPU_ISA_HINTS "ON")
|
||||||
set(ONEDNN_VERBOSE "OFF")
|
set(ONEDNN_VERBOSE "ON")
|
||||||
set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)
|
set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)
|
||||||
|
|
||||||
|
# TODO: Refactor this
|
||||||
|
if (ENABLE_X86_ISA)
|
||||||
|
# Note: only enable oneDNN for AVX512
|
||||||
|
list(APPEND DNNL_COMPILE_FLAGS ${CXX_COMPILE_FLAGS_AVX512})
|
||||||
|
else()
|
||||||
|
list(APPEND DNNL_COMPILE_FLAGS ${CXX_COMPILE_FLAGS})
|
||||||
|
endif()
|
||||||
|
|
||||||
set(VLLM_BUILD_TYPE ${CMAKE_BUILD_TYPE})
|
set(VLLM_BUILD_TYPE ${CMAKE_BUILD_TYPE})
|
||||||
set(CMAKE_BUILD_TYPE "Release") # remove oneDNN debug symbols to reduce size
|
set(CMAKE_BUILD_TYPE "Release") # remove oneDNN debug symbols to reduce size
|
||||||
FetchContent_MakeAvailable(oneDNN)
|
FetchContent_MakeAvailable(oneDNN)
|
||||||
@@ -348,14 +285,20 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON
|
|||||||
PRIVATE ${oneDNN_SOURCE_DIR}/src
|
PRIVATE ${oneDNN_SOURCE_DIR}/src
|
||||||
)
|
)
|
||||||
target_link_libraries(dnnl_ext dnnl torch)
|
target_link_libraries(dnnl_ext dnnl torch)
|
||||||
target_compile_options(dnnl_ext PRIVATE ${CXX_COMPILE_FLAGS} -fPIC)
|
target_compile_options(dnnl_ext PRIVATE ${DNNL_COMPILE_FLAGS} -fPIC)
|
||||||
list(APPEND LIBS dnnl_ext)
|
list(APPEND LIBS dnnl_ext)
|
||||||
set(USE_ONEDNN ON)
|
set(USE_ONEDNN ON)
|
||||||
else()
|
else()
|
||||||
set(USE_ONEDNN OFF)
|
set(USE_ONEDNN OFF)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}")
|
# TODO: Refactor this
|
||||||
|
if (ENABLE_X86_ISA)
|
||||||
|
message(STATUS "CPU extension (AVX512) compile flags: ${CXX_COMPILE_FLAGS_AVX512}")
|
||||||
|
message(STATUS "CPU extension (AVX2) compile flags: ${CXX_COMPILE_FLAGS_AVX2}")
|
||||||
|
else()
|
||||||
|
message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}")
|
||||||
|
endif()
|
||||||
|
|
||||||
if(ENABLE_NUMA)
|
if(ENABLE_NUMA)
|
||||||
list(APPEND LIBS numa)
|
list(APPEND LIBS numa)
|
||||||
@@ -390,25 +333,6 @@ set(VLLM_EXT_SRC
|
|||||||
"csrc/cpu/cpu_attn.cpp"
|
"csrc/cpu/cpu_attn.cpp"
|
||||||
"csrc/cpu/torch_bindings.cpp")
|
"csrc/cpu/torch_bindings.cpp")
|
||||||
|
|
||||||
if (AVX512_FOUND AND NOT AVX512_DISABLED)
|
|
||||||
set(VLLM_EXT_SRC
|
|
||||||
"csrc/cpu/shm.cpp"
|
|
||||||
"csrc/cpu/cpu_wna16.cpp"
|
|
||||||
"csrc/cpu/cpu_fused_moe.cpp"
|
|
||||||
${VLLM_EXT_SRC})
|
|
||||||
if (ENABLE_AVX512BF16 AND ENABLE_AVX512VNNI)
|
|
||||||
set(VLLM_EXT_SRC
|
|
||||||
"csrc/cpu/sgl-kernels/gemm.cpp"
|
|
||||||
"csrc/cpu/sgl-kernels/gemm_int8.cpp"
|
|
||||||
"csrc/cpu/sgl-kernels/gemm_fp8.cpp"
|
|
||||||
"csrc/cpu/sgl-kernels/moe.cpp"
|
|
||||||
"csrc/cpu/sgl-kernels/moe_int8.cpp"
|
|
||||||
"csrc/cpu/sgl-kernels/moe_fp8.cpp"
|
|
||||||
${VLLM_EXT_SRC})
|
|
||||||
add_compile_definitions(-DCPU_CAPABILITY_AVX512)
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if (ASIMD_FOUND AND NOT APPLE_SILICON_FOUND)
|
if (ASIMD_FOUND AND NOT APPLE_SILICON_FOUND)
|
||||||
set(VLLM_EXT_SRC
|
set(VLLM_EXT_SRC
|
||||||
"csrc/cpu/shm.cpp"
|
"csrc/cpu/shm.cpp"
|
||||||
@@ -421,21 +345,83 @@ if(USE_ONEDNN)
|
|||||||
${VLLM_EXT_SRC})
|
${VLLM_EXT_SRC})
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
message(STATUS "CPU extension source files: ${VLLM_EXT_SRC}")
|
if (ENABLE_X86_ISA)
|
||||||
|
set(VLLM_EXT_SRC_AVX512
|
||||||
|
"csrc/cpu/sgl-kernels/gemm.cpp"
|
||||||
|
"csrc/cpu/sgl-kernels/gemm_int8.cpp"
|
||||||
|
"csrc/cpu/sgl-kernels/gemm_fp8.cpp"
|
||||||
|
"csrc/cpu/sgl-kernels/moe.cpp"
|
||||||
|
"csrc/cpu/sgl-kernels/moe_int8.cpp"
|
||||||
|
"csrc/cpu/sgl-kernels/moe_fp8.cpp"
|
||||||
|
"csrc/cpu/shm.cpp"
|
||||||
|
"csrc/cpu/cpu_wna16.cpp"
|
||||||
|
"csrc/cpu/cpu_fused_moe.cpp"
|
||||||
|
"csrc/cpu/utils.cpp"
|
||||||
|
"csrc/cpu/cpu_attn.cpp"
|
||||||
|
"csrc/cpu/dnnl_kernels.cpp"
|
||||||
|
"csrc/cpu/torch_bindings.cpp"
|
||||||
|
# TODO: Remove these files
|
||||||
|
"csrc/cpu/activation.cpp"
|
||||||
|
"csrc/cpu/layernorm.cpp"
|
||||||
|
"csrc/cpu/mla_decode.cpp"
|
||||||
|
"csrc/cpu/pos_encoding.cpp"
|
||||||
|
"csrc/moe/dynamic_4bit_int_moe_cpu.cpp")
|
||||||
|
|
||||||
#
|
set(VLLM_EXT_SRC_AVX2
|
||||||
# Define extension targets
|
"csrc/cpu/utils.cpp"
|
||||||
#
|
"csrc/cpu/cpu_attn.cpp"
|
||||||
|
"csrc/cpu/torch_bindings.cpp"
|
||||||
|
# TODO: Remove these files
|
||||||
|
"csrc/cpu/activation.cpp"
|
||||||
|
"csrc/cpu/layernorm.cpp"
|
||||||
|
"csrc/cpu/mla_decode.cpp"
|
||||||
|
"csrc/cpu/pos_encoding.cpp"
|
||||||
|
"csrc/moe/dynamic_4bit_int_moe_cpu.cpp")
|
||||||
|
|
||||||
define_extension_target(
|
message(STATUS "CPU extension (AVX512) source files: ${VLLM_EXT_SRC_AVX512}")
|
||||||
_C
|
message(STATUS "CPU extension (AVX2) source files: ${VLLM_EXT_SRC_AVX2}")
|
||||||
DESTINATION vllm
|
|
||||||
LANGUAGE CXX
|
define_extension_target(
|
||||||
SOURCES ${VLLM_EXT_SRC}
|
_C
|
||||||
LIBRARIES ${LIBS}
|
DESTINATION vllm
|
||||||
COMPILE_FLAGS ${CXX_COMPILE_FLAGS}
|
LANGUAGE CXX
|
||||||
USE_SABI 3
|
SOURCES ${VLLM_EXT_SRC_AVX512}
|
||||||
WITH_SOABI
|
LIBRARIES ${LIBS}
|
||||||
)
|
COMPILE_FLAGS ${CXX_COMPILE_FLAGS_AVX512}
|
||||||
|
USE_SABI 3
|
||||||
|
WITH_SOABI
|
||||||
|
)
|
||||||
|
|
||||||
|
# For SGL kernels
|
||||||
|
target_compile_definitions(_C PRIVATE "-DCPU_CAPABILITY_AVX512")
|
||||||
|
# For AMX kernels
|
||||||
|
target_compile_definitions(_C PRIVATE "-DCPU_CAPABILITY_AMXBF16")
|
||||||
|
|
||||||
|
define_extension_target(
|
||||||
|
_C_AVX2
|
||||||
|
DESTINATION vllm
|
||||||
|
LANGUAGE CXX
|
||||||
|
SOURCES ${VLLM_EXT_SRC_AVX2}
|
||||||
|
LIBRARIES ${LIBS}
|
||||||
|
COMPILE_FLAGS ${CXX_COMPILE_FLAGS_AVX2}
|
||||||
|
USE_SABI 3
|
||||||
|
WITH_SOABI
|
||||||
|
)
|
||||||
|
else()
|
||||||
|
message(STATUS "CPU extension source files: ${VLLM_EXT_SRC}")
|
||||||
|
#
|
||||||
|
# Define extension targets
|
||||||
|
#
|
||||||
|
define_extension_target(
|
||||||
|
_C
|
||||||
|
DESTINATION vllm
|
||||||
|
LANGUAGE CXX
|
||||||
|
SOURCES ${VLLM_EXT_SRC}
|
||||||
|
LIBRARIES ${LIBS}
|
||||||
|
COMPILE_FLAGS ${CXX_COMPILE_FLAGS}
|
||||||
|
USE_SABI 3
|
||||||
|
WITH_SOABI
|
||||||
|
)
|
||||||
|
endif()
|
||||||
|
|
||||||
message(STATUS "Enabling C extension.")
|
message(STATUS "Enabling C extension.")
|
||||||
|
|||||||
@@ -17,7 +17,8 @@ endif()
|
|||||||
# They should be identical but if they aren't, this is a massive footgun.
|
# They should be identical but if they aren't, this is a massive footgun.
|
||||||
#
|
#
|
||||||
# The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place.
|
# The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place.
|
||||||
# To only install vllm-flash-attn, use --component _vllm_fa2_C (for FA2) or --component _vllm_fa3_C (for FA3).
|
# To only install vllm-flash-attn, use --component _vllm_fa2_C (for FA2), --component _vllm_fa3_C (for FA3),
|
||||||
|
# or --component _vllm_fa4_cutedsl_C (for FA4 CuteDSL Python files).
|
||||||
# If no component is specified, vllm-flash-attn is still installed.
|
# If no component is specified, vllm-flash-attn is still installed.
|
||||||
|
|
||||||
# If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading.
|
# If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading.
|
||||||
@@ -38,22 +39,16 @@ else()
|
|||||||
FetchContent_Declare(
|
FetchContent_Declare(
|
||||||
vllm-flash-attn
|
vllm-flash-attn
|
||||||
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
|
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
|
||||||
GIT_TAG 5824e6e2008271063c3229ab3e7032bd74abbbc6
|
GIT_TAG 140c00c0241bb60cc6e44e7c1be9998d4b20d8d2
|
||||||
GIT_PROGRESS TRUE
|
GIT_PROGRESS TRUE
|
||||||
# Don't share the vllm-flash-attn build between build types
|
# Don't share the vllm-flash-attn build between build types
|
||||||
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
|
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
|
||||||
)
|
)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
|
||||||
# Ensure the vllm/vllm_flash_attn directory exists before installation
|
|
||||||
install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn\")" ALL_COMPONENTS)
|
|
||||||
|
|
||||||
# Make sure vllm-flash-attn install rules are nested under vllm/
|
# Make sure vllm-flash-attn install rules are nested under vllm/
|
||||||
# This is here to support installing all components under the same prefix with cmake --install.
|
# ALL_COMPONENTS ensures the save/modify/restore runs exactly once regardless
|
||||||
# setup.py installs every component separately but uses the same prefix for all.
|
# of how many components are being installed, avoiding double-append of /vllm/.
|
||||||
# ALL_COMPONENTS is used to avoid duplication for FA2 and FA3,
|
|
||||||
# and these statements don't hurt when installing neither component.
|
|
||||||
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY FALSE)" ALL_COMPONENTS)
|
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY FALSE)" ALL_COMPONENTS)
|
||||||
install(CODE "set(OLD_CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}\")" ALL_COMPONENTS)
|
install(CODE "set(OLD_CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}\")" ALL_COMPONENTS)
|
||||||
install(CODE "set(CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}/vllm/\")" ALL_COMPONENTS)
|
install(CODE "set(CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}/vllm/\")" ALL_COMPONENTS)
|
||||||
@@ -62,22 +57,48 @@ install(CODE "set(CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}/vllm/\")" ALL_
|
|||||||
FetchContent_MakeAvailable(vllm-flash-attn)
|
FetchContent_MakeAvailable(vllm-flash-attn)
|
||||||
message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}")
|
message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}")
|
||||||
|
|
||||||
# Restore the install prefix
|
# Restore the install prefix after FA's install rules
|
||||||
install(CODE "set(CMAKE_INSTALL_PREFIX \"\${OLD_CMAKE_INSTALL_PREFIX}\")" ALL_COMPONENTS)
|
install(CODE "set(CMAKE_INSTALL_PREFIX \"\${OLD_CMAKE_INSTALL_PREFIX}\")" ALL_COMPONENTS)
|
||||||
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)
|
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)
|
||||||
|
|
||||||
# Copy over the vllm-flash-attn python files (duplicated for fa2 and fa3, in
|
# Install shared Python files for both FA2 and FA3 components
|
||||||
# case only one is built, in the case both are built redundant work is done)
|
foreach(_FA_COMPONENT _vllm_fa2_C _vllm_fa3_C)
|
||||||
install(
|
# Ensure the vllm/vllm_flash_attn directory exists before installation
|
||||||
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
|
install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn\")"
|
||||||
DESTINATION vllm/vllm_flash_attn
|
COMPONENT ${_FA_COMPONENT})
|
||||||
COMPONENT _vllm_fa2_C
|
|
||||||
FILES_MATCHING PATTERN "*.py"
|
|
||||||
)
|
|
||||||
|
|
||||||
install(
|
# Copy vllm_flash_attn python files (except __init__.py and flash_attn_interface.py
|
||||||
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
|
# which are source-controlled in vllm)
|
||||||
DESTINATION vllm/vllm_flash_attn
|
install(
|
||||||
COMPONENT _vllm_fa3_C
|
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
|
||||||
FILES_MATCHING PATTERN "*.py"
|
DESTINATION vllm/vllm_flash_attn
|
||||||
)
|
COMPONENT ${_FA_COMPONENT}
|
||||||
|
FILES_MATCHING PATTERN "*.py"
|
||||||
|
PATTERN "__init__.py" EXCLUDE
|
||||||
|
PATTERN "flash_attn_interface.py" EXCLUDE
|
||||||
|
)
|
||||||
|
|
||||||
|
endforeach()
|
||||||
|
|
||||||
|
#
|
||||||
|
# FA4 CuteDSL component
|
||||||
|
# This is a Python-only component that copies the flash_attn/cute directory
|
||||||
|
# and transforms imports to match our package structure.
|
||||||
|
#
|
||||||
|
add_custom_target(_vllm_fa4_cutedsl_C)
|
||||||
|
|
||||||
|
# Copy flash_attn/cute directory (needed for FA4) and transform imports
|
||||||
|
# The cute directory uses flash_attn.cute imports internally, which we replace
|
||||||
|
# with vllm.vllm_flash_attn.cute to match our package structure.
|
||||||
|
install(CODE "
|
||||||
|
file(GLOB_RECURSE CUTE_PY_FILES \"${vllm-flash-attn_SOURCE_DIR}/flash_attn/cute/*.py\")
|
||||||
|
foreach(SRC_FILE \${CUTE_PY_FILES})
|
||||||
|
file(RELATIVE_PATH REL_PATH \"${vllm-flash-attn_SOURCE_DIR}/flash_attn/cute\" \${SRC_FILE})
|
||||||
|
set(DST_FILE \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn/cute/\${REL_PATH}\")
|
||||||
|
get_filename_component(DST_DIR \${DST_FILE} DIRECTORY)
|
||||||
|
file(MAKE_DIRECTORY \${DST_DIR})
|
||||||
|
file(READ \${SRC_FILE} FILE_CONTENTS)
|
||||||
|
string(REPLACE \"flash_attn.cute\" \"vllm.vllm_flash_attn.cute\" FILE_CONTENTS \"\${FILE_CONTENTS}\")
|
||||||
|
file(WRITE \${DST_FILE} \"\${FILE_CONTENTS}\")
|
||||||
|
endforeach()
|
||||||
|
" COMPONENT _vllm_fa4_cutedsl_C)
|
||||||
|
|||||||
@@ -5,117 +5,11 @@
|
|||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
|
||||||
#include "cuda_compat.h"
|
#include "cuda_compat.h"
|
||||||
|
#include "cuda_vec_utils.cuh"
|
||||||
#include "dispatch_utils.h"
|
#include "dispatch_utils.h"
|
||||||
|
|
||||||
namespace vllm {
|
namespace vllm {
|
||||||
|
|
||||||
struct alignas(32) u32x8_t {
|
|
||||||
uint32_t u0, u1, u2, u3, u4, u5, u6, u7;
|
|
||||||
};
|
|
||||||
|
|
||||||
__device__ __forceinline__ void ld256(u32x8_t& val, const u32x8_t* ptr) {
|
|
||||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000 && \
|
|
||||||
defined(CUDA_VERSION) && CUDA_VERSION >= 12090
|
|
||||||
asm volatile("ld.global.nc.v8.u32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%8];\n"
|
|
||||||
: "=r"(val.u0), "=r"(val.u1), "=r"(val.u2), "=r"(val.u3),
|
|
||||||
"=r"(val.u4), "=r"(val.u5), "=r"(val.u6), "=r"(val.u7)
|
|
||||||
: "l"(ptr));
|
|
||||||
#else
|
|
||||||
const uint4* uint_ptr = reinterpret_cast<const uint4*>(ptr);
|
|
||||||
uint4 top_half = __ldg(&uint_ptr[0]);
|
|
||||||
uint4 bottom_half = __ldg(&uint_ptr[1]);
|
|
||||||
val.u0 = top_half.x;
|
|
||||||
val.u1 = top_half.y;
|
|
||||||
val.u2 = top_half.z;
|
|
||||||
val.u3 = top_half.w;
|
|
||||||
val.u4 = bottom_half.x;
|
|
||||||
val.u5 = bottom_half.y;
|
|
||||||
val.u6 = bottom_half.z;
|
|
||||||
val.u7 = bottom_half.w;
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
__device__ __forceinline__ void st256(u32x8_t& val, u32x8_t* ptr) {
|
|
||||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000 && \
|
|
||||||
defined(CUDA_VERSION) && CUDA_VERSION >= 12090
|
|
||||||
asm volatile("st.global.v8.u32 [%0], {%1,%2,%3,%4,%5,%6,%7,%8};\n"
|
|
||||||
:
|
|
||||||
: "l"(ptr), "r"(val.u0), "r"(val.u1), "r"(val.u2), "r"(val.u3),
|
|
||||||
"r"(val.u4), "r"(val.u5), "r"(val.u6), "r"(val.u7)
|
|
||||||
: "memory");
|
|
||||||
#else
|
|
||||||
uint4* uint_ptr = reinterpret_cast<uint4*>(ptr);
|
|
||||||
uint_ptr[0] = make_uint4(val.u0, val.u1, val.u2, val.u3);
|
|
||||||
uint_ptr[1] = make_uint4(val.u4, val.u5, val.u6, val.u7);
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
template <bool support_256>
|
|
||||||
struct VecTraits;
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct VecTraits<true> {
|
|
||||||
static constexpr int ARCH_MAX_VEC_SIZE = 32;
|
|
||||||
using vec_t = u32x8_t;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct VecTraits<false> {
|
|
||||||
static constexpr int ARCH_MAX_VEC_SIZE = 16;
|
|
||||||
using vec_t = int4;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
struct PackedTraits;
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct PackedTraits<c10::BFloat16> {
|
|
||||||
using packed_t = __nv_bfloat162;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct PackedTraits<c10::Half> {
|
|
||||||
using packed_t = __half2;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct PackedTraits<float> {
|
|
||||||
using packed_t = float2;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename packed_t>
|
|
||||||
__device__ __forceinline__ float2 cast_to_float2(const packed_t& val) {
|
|
||||||
if constexpr (std::is_same_v<packed_t, __nv_bfloat162>) {
|
|
||||||
return __bfloat1622float2(val);
|
|
||||||
} else if constexpr (std::is_same_v<packed_t, __half2>) {
|
|
||||||
return __half22float2(val);
|
|
||||||
} else if constexpr (std::is_same_v<packed_t, float2>) {
|
|
||||||
return float2(val);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename packed_t>
|
|
||||||
__device__ __forceinline__ packed_t cast_to_packed(const float2& val) {
|
|
||||||
if constexpr (std::is_same_v<packed_t, __nv_bfloat162>) {
|
|
||||||
return __float22bfloat162_rn(val);
|
|
||||||
} else if constexpr (std::is_same_v<packed_t, __half2>) {
|
|
||||||
return __float22half2_rn(val);
|
|
||||||
} else if constexpr (std::is_same_v<packed_t, float2>) {
|
|
||||||
return float2(val);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename packed_t>
|
|
||||||
__device__ __forceinline__ packed_t packed_mul(const packed_t& x,
|
|
||||||
const packed_t& y) {
|
|
||||||
if constexpr (std::is_same_v<packed_t, __nv_bfloat162> ||
|
|
||||||
std::is_same_v<packed_t, __half2>) {
|
|
||||||
return __hmul2(x, y);
|
|
||||||
} else if constexpr (std::is_same_v<packed_t, float2>) {
|
|
||||||
return make_float2(x.x * y.x, x.y * y.y);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
|
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
|
||||||
bool act_first>
|
bool act_first>
|
||||||
__device__ __forceinline__ scalar_t compute(const scalar_t& x,
|
__device__ __forceinline__ scalar_t compute(const scalar_t& x,
|
||||||
@@ -131,16 +25,6 @@ __device__ __forceinline__ packed_t packed_compute(const packed_t& x,
|
|||||||
: packed_mul(x, PACKED_ACT_FN(y));
|
: packed_mul(x, PACKED_ACT_FN(y));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if all pointers are 16-byte aligned for int4 vectorized access
|
|
||||||
__host__ __device__ __forceinline__ bool is_16byte_aligned(const void* ptr) {
|
|
||||||
return (reinterpret_cast<uintptr_t>(ptr) & 15) == 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if all pointers are 16-byte aligned for longlong4_32a vectorized access
|
|
||||||
__host__ __device__ __forceinline__ bool is_32byte_aligned(const void* ptr) {
|
|
||||||
return (reinterpret_cast<uintptr_t>(ptr) & 31) == 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Activation and gating kernel template.
|
// Activation and gating kernel template.
|
||||||
template <typename scalar_t, typename packed_t,
|
template <typename scalar_t, typename packed_t,
|
||||||
scalar_t (*ACT_FN)(const scalar_t&),
|
scalar_t (*ACT_FN)(const scalar_t&),
|
||||||
@@ -155,36 +39,32 @@ __global__ void act_and_mul_kernel(
|
|||||||
scalar_t* out_ptr = out + blockIdx.x * d;
|
scalar_t* out_ptr = out + blockIdx.x * d;
|
||||||
|
|
||||||
if constexpr (use_vec) {
|
if constexpr (use_vec) {
|
||||||
// Fast path: 128-bit/256-bit vectorized loop
|
using cuda_t = typename CUDATypeConverter<scalar_t>::Type;
|
||||||
using vec_t = typename VecTraits<use_256b>::vec_t;
|
using pvec_t = PackedVec<cuda_t, use_256b>;
|
||||||
constexpr int ARCH_MAX_VEC_SIZE = VecTraits<use_256b>::ARCH_MAX_VEC_SIZE;
|
|
||||||
constexpr int VEC_SIZE = ARCH_MAX_VEC_SIZE / sizeof(packed_t);
|
|
||||||
|
|
||||||
const vec_t* x_vec = reinterpret_cast<const vec_t*>(x_ptr);
|
const pvec_t* x_vec = reinterpret_cast<const pvec_t*>(x_ptr);
|
||||||
const vec_t* y_vec = reinterpret_cast<const vec_t*>(y_ptr);
|
const pvec_t* y_vec = reinterpret_cast<const pvec_t*>(y_ptr);
|
||||||
vec_t* out_vec = reinterpret_cast<vec_t*>(out_ptr);
|
pvec_t* out_vec = reinterpret_cast<pvec_t*>(out_ptr);
|
||||||
const int num_vecs = d / 2 / VEC_SIZE;
|
const int num_vecs = d / 2 / pvec_t::NUM_ELTS;
|
||||||
|
|
||||||
for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
|
for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
|
||||||
vec_t x, y;
|
pvec_t x, y;
|
||||||
if constexpr (use_256b) {
|
if constexpr (use_256b) {
|
||||||
ld256(x, &x_vec[i]);
|
ld256(x, &x_vec[i]);
|
||||||
ld256(y, &y_vec[i]);
|
ld256(y, &y_vec[i]);
|
||||||
} else {
|
} else {
|
||||||
x = VLLM_LDG(&x_vec[i]);
|
ld128(x, &x_vec[i]);
|
||||||
y = VLLM_LDG(&y_vec[i]);
|
ld128(y, &y_vec[i]);
|
||||||
}
|
}
|
||||||
auto* xp = reinterpret_cast<packed_t*>(&x);
|
|
||||||
auto* yp = reinterpret_cast<packed_t*>(&y);
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < VEC_SIZE; j++) {
|
for (int j = 0; j < pvec_t::NUM_ELTS; j++) {
|
||||||
xp[j] =
|
x.elts[j] = packed_compute<packed_t, PACKED_ACT_FN, act_first>(
|
||||||
packed_compute<packed_t, PACKED_ACT_FN, act_first>(xp[j], yp[j]);
|
x.elts[j], y.elts[j]);
|
||||||
}
|
}
|
||||||
if constexpr (use_256b) {
|
if constexpr (use_256b) {
|
||||||
st256(x, &out_vec[i]);
|
st256(x, &out_vec[i]);
|
||||||
} else {
|
} else {
|
||||||
out_vec[i] = x;
|
st128(x, &out_vec[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -272,51 +152,54 @@ packed_gelu_tanh_kernel(const packed_t& val) {
|
|||||||
// Launch activation and gating kernel.
|
// Launch activation and gating kernel.
|
||||||
// Use ACT_FIRST (bool) indicating whether to apply the activation function
|
// Use ACT_FIRST (bool) indicating whether to apply the activation function
|
||||||
// first.
|
// first.
|
||||||
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, PACKED_KERNEL, ACT_FIRST) \
|
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, PACKED_KERNEL, ACT_FIRST) \
|
||||||
auto dtype = input.scalar_type(); \
|
auto dtype = input.scalar_type(); \
|
||||||
int d = input.size(-1) / 2; \
|
int d = input.size(-1) / 2; \
|
||||||
int64_t num_tokens = input.numel() / input.size(-1); \
|
int64_t num_tokens = input.numel() / input.size(-1); \
|
||||||
if (num_tokens == 0) { \
|
if (num_tokens == 0) { \
|
||||||
return; \
|
return; \
|
||||||
} \
|
} \
|
||||||
dim3 grid(num_tokens); \
|
dim3 grid(num_tokens); \
|
||||||
int cc_major = at::cuda::getCurrentDeviceProperties()->major; \
|
int cc_major = at::cuda::getCurrentDeviceProperties()->major; \
|
||||||
int support_vec = (cc_major >= 10 && num_tokens > 128) ? 32 : 16; \
|
int support_vec = \
|
||||||
int vec_size = support_vec / at::elementSize(dtype); \
|
(CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) \
|
||||||
const bool use_vec = (d % vec_size == 0); \
|
? vllm::VecTraits<true>::ARCH_MAX_VEC_SIZE \
|
||||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
|
: vllm::VecTraits<false>::ARCH_MAX_VEC_SIZE; \
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
|
int vec_size = support_vec / at::elementSize(dtype); \
|
||||||
if (use_vec) { \
|
const bool use_vec = (d % vec_size == 0); \
|
||||||
dim3 block(std::min(d / vec_size, 1024)); \
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
|
||||||
if (cc_major >= 10 && num_tokens > 128) { \
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
|
||||||
VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \
|
if (use_vec) { \
|
||||||
vllm::act_and_mul_kernel< \
|
dim3 block(std::min(d / vec_size, 1024)); \
|
||||||
scalar_t, typename vllm::PackedTraits<scalar_t>::packed_t, \
|
if (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) { \
|
||||||
KERNEL<scalar_t>, \
|
VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \
|
||||||
PACKED_KERNEL<typename vllm::PackedTraits<scalar_t>::packed_t>, \
|
vllm::act_and_mul_kernel< \
|
||||||
ACT_FIRST, true, true><<<grid, block, 0, stream>>>( \
|
scalar_t, typename vllm::PackedTypeConverter<scalar_t>::Type, \
|
||||||
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d); \
|
KERNEL<scalar_t>, \
|
||||||
}); \
|
PACKED_KERNEL<typename vllm::PackedTypeConverter<scalar_t>::Type>, \
|
||||||
} else { \
|
ACT_FIRST, true, true><<<grid, block, 0, stream>>>( \
|
||||||
VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \
|
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d); \
|
||||||
vllm::act_and_mul_kernel< \
|
}); \
|
||||||
scalar_t, typename vllm::PackedTraits<scalar_t>::packed_t, \
|
} else { \
|
||||||
KERNEL<scalar_t>, \
|
VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \
|
||||||
PACKED_KERNEL<typename vllm::PackedTraits<scalar_t>::packed_t>, \
|
vllm::act_and_mul_kernel< \
|
||||||
ACT_FIRST, true, false><<<grid, block, 0, stream>>>( \
|
scalar_t, typename vllm::PackedTypeConverter<scalar_t>::Type, \
|
||||||
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d); \
|
KERNEL<scalar_t>, \
|
||||||
}); \
|
PACKED_KERNEL<typename vllm::PackedTypeConverter<scalar_t>::Type>, \
|
||||||
} \
|
ACT_FIRST, true, false><<<grid, block, 0, stream>>>( \
|
||||||
} else { \
|
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d); \
|
||||||
dim3 block(std::min(d, 1024)); \
|
}); \
|
||||||
VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \
|
} \
|
||||||
vllm::act_and_mul_kernel< \
|
} else { \
|
||||||
scalar_t, typename vllm::PackedTraits<scalar_t>::packed_t, \
|
dim3 block(std::min(d, 1024)); \
|
||||||
KERNEL<scalar_t>, \
|
VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \
|
||||||
PACKED_KERNEL<typename vllm::PackedTraits<scalar_t>::packed_t>, \
|
vllm::act_and_mul_kernel< \
|
||||||
ACT_FIRST, false><<<grid, block, 0, stream>>>( \
|
scalar_t, typename vllm::PackedTypeConverter<scalar_t>::Type, \
|
||||||
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d); \
|
KERNEL<scalar_t>, \
|
||||||
}); \
|
PACKED_KERNEL<typename vllm::PackedTypeConverter<scalar_t>::Type>, \
|
||||||
|
ACT_FIRST, false><<<grid, block, 0, stream>>>( \
|
||||||
|
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d); \
|
||||||
|
}); \
|
||||||
}
|
}
|
||||||
|
|
||||||
void silu_and_mul(torch::Tensor& out, // [..., d]
|
void silu_and_mul(torch::Tensor& out, // [..., d]
|
||||||
@@ -378,35 +261,31 @@ __global__ void act_and_mul_kernel_with_param(
|
|||||||
scalar_t* out_ptr = out + blockIdx.x * d;
|
scalar_t* out_ptr = out + blockIdx.x * d;
|
||||||
|
|
||||||
if constexpr (use_vec) {
|
if constexpr (use_vec) {
|
||||||
// Fast path: 128-bit/256-bit vectorized loop
|
using cuda_t = typename CUDATypeConverter<scalar_t>::Type;
|
||||||
using vec_t = typename VecTraits<use_256b>::vec_t;
|
using pvec_t = PackedVec<cuda_t, use_256b>;
|
||||||
constexpr int ARCH_MAX_VEC_SIZE = VecTraits<use_256b>::ARCH_MAX_VEC_SIZE;
|
|
||||||
constexpr int VEC_SIZE = ARCH_MAX_VEC_SIZE / sizeof(packed_t);
|
|
||||||
|
|
||||||
const vec_t* x_vec = reinterpret_cast<const vec_t*>(x_ptr);
|
const pvec_t* x_vec = reinterpret_cast<const pvec_t*>(x_ptr);
|
||||||
const vec_t* y_vec = reinterpret_cast<const vec_t*>(y_ptr);
|
const pvec_t* y_vec = reinterpret_cast<const pvec_t*>(y_ptr);
|
||||||
vec_t* out_vec = reinterpret_cast<vec_t*>(out_ptr);
|
pvec_t* out_vec = reinterpret_cast<pvec_t*>(out_ptr);
|
||||||
const int num_vecs = d / 2 / VEC_SIZE;
|
const int num_vecs = d / 2 / pvec_t::NUM_ELTS;
|
||||||
|
|
||||||
for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
|
for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
|
||||||
vec_t x, y;
|
pvec_t x, y;
|
||||||
if constexpr (use_256b) {
|
if constexpr (use_256b) {
|
||||||
ld256(x, &x_vec[i]);
|
ld256(x, &x_vec[i]);
|
||||||
ld256(y, &y_vec[i]);
|
ld256(y, &y_vec[i]);
|
||||||
} else {
|
} else {
|
||||||
x = VLLM_LDG(&x_vec[i]);
|
ld128(x, &x_vec[i]);
|
||||||
y = VLLM_LDG(&y_vec[i]);
|
ld128(y, &y_vec[i]);
|
||||||
}
|
}
|
||||||
auto* xp = reinterpret_cast<packed_t*>(&x);
|
|
||||||
auto* yp = reinterpret_cast<packed_t*>(&y);
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < VEC_SIZE; j++) {
|
for (int j = 0; j < pvec_t::NUM_ELTS; j++) {
|
||||||
xp[j] = packed_mul(PACKED_ACT_FN(xp[j], param), yp[j]);
|
x.elts[j] = packed_mul(PACKED_ACT_FN(x.elts[j], param), y.elts[j]);
|
||||||
}
|
}
|
||||||
if constexpr (use_256b) {
|
if constexpr (use_256b) {
|
||||||
st256(x, &out_vec[i]);
|
st256(x, &out_vec[i]);
|
||||||
} else {
|
} else {
|
||||||
out_vec[i] = x;
|
st128(x, &out_vec[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -499,21 +378,24 @@ __global__ void swigluoai_and_mul_kernel(
|
|||||||
} \
|
} \
|
||||||
dim3 grid(num_tokens); \
|
dim3 grid(num_tokens); \
|
||||||
int cc_major = at::cuda::getCurrentDeviceProperties()->major; \
|
int cc_major = at::cuda::getCurrentDeviceProperties()->major; \
|
||||||
int support_vec = (cc_major >= 10 && num_tokens > 128) ? 32 : 16; \
|
int support_vec = \
|
||||||
|
(CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) \
|
||||||
|
? vllm::VecTraits<true>::ARCH_MAX_VEC_SIZE \
|
||||||
|
: vllm::VecTraits<false>::ARCH_MAX_VEC_SIZE; \
|
||||||
int vec_size = support_vec / at::elementSize(dtype); \
|
int vec_size = support_vec / at::elementSize(dtype); \
|
||||||
const bool use_vec = (d % vec_size == 0); \
|
const bool use_vec = (d % vec_size == 0); \
|
||||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
|
||||||
if (use_vec) { \
|
if (use_vec) { \
|
||||||
dim3 block(std::min(d / vec_size, 1024)); \
|
dim3 block(std::min(d / vec_size, 1024)); \
|
||||||
if (cc_major >= 10 && num_tokens > 128) { \
|
if (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) { \
|
||||||
VLLM_DISPATCH_FLOATING_TYPES( \
|
VLLM_DISPATCH_FLOATING_TYPES( \
|
||||||
dtype, "act_and_mul_kernel_with_param", [&] { \
|
dtype, "act_and_mul_kernel_with_param", [&] { \
|
||||||
vllm::act_and_mul_kernel_with_param< \
|
vllm::act_and_mul_kernel_with_param< \
|
||||||
scalar_t, typename vllm::PackedTraits<scalar_t>::packed_t, \
|
scalar_t, typename vllm::PackedTypeConverter<scalar_t>::Type, \
|
||||||
KERNEL<scalar_t>, \
|
KERNEL<scalar_t>, \
|
||||||
PACKED_KERNEL< \
|
PACKED_KERNEL< \
|
||||||
typename vllm::PackedTraits<scalar_t>::packed_t>, \
|
typename vllm::PackedTypeConverter<scalar_t>::Type>, \
|
||||||
true, true><<<grid, block, 0, stream>>>( \
|
true, true><<<grid, block, 0, stream>>>( \
|
||||||
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d, \
|
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d, \
|
||||||
PARAM); \
|
PARAM); \
|
||||||
@@ -522,10 +404,10 @@ __global__ void swigluoai_and_mul_kernel(
|
|||||||
VLLM_DISPATCH_FLOATING_TYPES( \
|
VLLM_DISPATCH_FLOATING_TYPES( \
|
||||||
dtype, "act_and_mul_kernel_with_param", [&] { \
|
dtype, "act_and_mul_kernel_with_param", [&] { \
|
||||||
vllm::act_and_mul_kernel_with_param< \
|
vllm::act_and_mul_kernel_with_param< \
|
||||||
scalar_t, typename vllm::PackedTraits<scalar_t>::packed_t, \
|
scalar_t, typename vllm::PackedTypeConverter<scalar_t>::Type, \
|
||||||
KERNEL<scalar_t>, \
|
KERNEL<scalar_t>, \
|
||||||
PACKED_KERNEL< \
|
PACKED_KERNEL< \
|
||||||
typename vllm::PackedTraits<scalar_t>::packed_t>, \
|
typename vllm::PackedTypeConverter<scalar_t>::Type>, \
|
||||||
true, false><<<grid, block, 0, stream>>>( \
|
true, false><<<grid, block, 0, stream>>>( \
|
||||||
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d, \
|
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d, \
|
||||||
PARAM); \
|
PARAM); \
|
||||||
@@ -535,9 +417,9 @@ __global__ void swigluoai_and_mul_kernel(
|
|||||||
dim3 block(std::min(d, 1024)); \
|
dim3 block(std::min(d, 1024)); \
|
||||||
VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel_with_param", [&] { \
|
VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel_with_param", [&] { \
|
||||||
vllm::act_and_mul_kernel_with_param< \
|
vllm::act_and_mul_kernel_with_param< \
|
||||||
scalar_t, typename vllm::PackedTraits<scalar_t>::packed_t, \
|
scalar_t, typename vllm::PackedTypeConverter<scalar_t>::Type, \
|
||||||
KERNEL<scalar_t>, \
|
KERNEL<scalar_t>, \
|
||||||
PACKED_KERNEL<typename vllm::PackedTraits<scalar_t>::packed_t>, \
|
PACKED_KERNEL<typename vllm::PackedTypeConverter<scalar_t>::Type>, \
|
||||||
false><<<grid, block, 0, stream>>>( \
|
false><<<grid, block, 0, stream>>>( \
|
||||||
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d, PARAM); \
|
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d, PARAM); \
|
||||||
}); \
|
}); \
|
||||||
@@ -629,14 +511,17 @@ __global__ void activation_kernel(
|
|||||||
} \
|
} \
|
||||||
dim3 grid(num_tokens); \
|
dim3 grid(num_tokens); \
|
||||||
int cc_major = at::cuda::getCurrentDeviceProperties()->major; \
|
int cc_major = at::cuda::getCurrentDeviceProperties()->major; \
|
||||||
int support_vec = (cc_major >= 10 && num_tokens > 128) ? 32 : 16; \
|
int support_vec = \
|
||||||
|
(CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) \
|
||||||
|
? vllm::VecTraits<true>::ARCH_MAX_VEC_SIZE \
|
||||||
|
: vllm::VecTraits<false>::ARCH_MAX_VEC_SIZE; \
|
||||||
int vec_size = support_vec / at::elementSize(dtype); \
|
int vec_size = support_vec / at::elementSize(dtype); \
|
||||||
const bool use_vec = (d % vec_size == 0); \
|
const bool use_vec = (d % vec_size == 0); \
|
||||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
|
||||||
if (use_vec) { \
|
if (use_vec) { \
|
||||||
dim3 block(std::min(d / vec_size, 1024)); \
|
dim3 block(std::min(d / vec_size, 1024)); \
|
||||||
if (cc_major >= 10 && num_tokens > 128) { \
|
if (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) { \
|
||||||
VLLM_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] { \
|
VLLM_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] { \
|
||||||
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>, true, true> \
|
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>, true, true> \
|
||||||
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
|
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
|
||||||
|
|||||||
@@ -4,6 +4,10 @@
|
|||||||
|
|
||||||
#include <torch/library.h>
|
#include <torch/library.h>
|
||||||
|
|
||||||
|
// Note: overwrite the external defination for sharing same name between
|
||||||
|
// libraries use different ISAs.
|
||||||
|
#define TORCH_EXTENSION_NAME _C
|
||||||
|
|
||||||
std::string init_cpu_threads_env(const std::string& cpu_ids);
|
std::string init_cpu_threads_env(const std::string& cpu_ids);
|
||||||
|
|
||||||
void release_dnnl_matmul_handler(int64_t handler);
|
void release_dnnl_matmul_handler(int64_t handler);
|
||||||
@@ -324,19 +328,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
"str act, str isa) -> ()");
|
"str act, str isa) -> ()");
|
||||||
ops.impl("cpu_fused_moe", torch::kCPU, &cpu_fused_moe);
|
ops.impl("cpu_fused_moe", torch::kCPU, &cpu_fused_moe);
|
||||||
#endif
|
#endif
|
||||||
}
|
ops.def("init_cpu_threads_env(str cpu_ids) -> str", &init_cpu_threads_env);
|
||||||
|
ops.def(
|
||||||
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) {
|
|
||||||
// CPU utils
|
|
||||||
utils.def("init_cpu_threads_env(str cpu_ids) -> str", &init_cpu_threads_env);
|
|
||||||
}
|
|
||||||
|
|
||||||
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cpu), cpu_ops) {
|
|
||||||
cpu_ops.def(
|
|
||||||
"mla_decode_kvcache("
|
"mla_decode_kvcache("
|
||||||
" Tensor! out, Tensor query, Tensor kv_cache,"
|
" Tensor! out, Tensor query, Tensor kv_cache,"
|
||||||
" float scale, Tensor block_tables, Tensor seq_lens) -> ()");
|
" float scale, Tensor block_tables, Tensor seq_lens) -> ()");
|
||||||
cpu_ops.impl("mla_decode_kvcache", torch::kCPU, &mla_decode_kvcache);
|
ops.impl("mla_decode_kvcache", torch::kCPU, &mla_decode_kvcache);
|
||||||
}
|
}
|
||||||
|
|
||||||
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
||||||
|
|||||||
334
csrc/cuda_vec_utils.cuh
Normal file
334
csrc/cuda_vec_utils.cuh
Normal file
@@ -0,0 +1,334 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <c10/util/BFloat16.h>
|
||||||
|
#include <c10/util/Half.h>
|
||||||
|
#include <cassert>
|
||||||
|
|
||||||
|
#ifdef USE_ROCM
|
||||||
|
#include <hip/hip_runtime.h>
|
||||||
|
#else
|
||||||
|
#include <cuda_bf16.h>
|
||||||
|
#include <cuda_fp16.h>
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// Device-side: SM100+ architecture with CUDA 12.9+ toolkit, which
|
||||||
|
// together enable 256-bit (v8.u32) PTX load/store instructions.
|
||||||
|
// Use for PTX instruction selection with architecture fallback paths.
|
||||||
|
#if !defined(USE_ROCM) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000 && \
|
||||||
|
defined(CUDA_VERSION) && CUDA_VERSION >= 12090
|
||||||
|
#define VLLM_256B_PTX_ENABLED 1
|
||||||
|
#else
|
||||||
|
#define VLLM_256B_PTX_ENABLED 0
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace vllm {
|
||||||
|
|
||||||
|
// ============================================================
|
||||||
|
// Types and traits
|
||||||
|
// ============================================================
|
||||||
|
|
||||||
|
// 256-bit (32-byte) aligned vector type: 8 x uint32_t
|
||||||
|
struct alignas(32) u32x8_t {
|
||||||
|
uint32_t d[8];
|
||||||
|
};
|
||||||
|
|
||||||
|
// VecTraits — select between 128-bit (int4) and 256-bit
|
||||||
|
// (u32x8_t) vector types at compile time.
|
||||||
|
template <bool support_256>
|
||||||
|
struct VecTraits;
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct VecTraits<true> {
|
||||||
|
static constexpr int ARCH_MAX_VEC_SIZE = 32;
|
||||||
|
using vec_t = u32x8_t;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct VecTraits<false> {
|
||||||
|
static constexpr int ARCH_MAX_VEC_SIZE = 16;
|
||||||
|
using vec_t = int4;
|
||||||
|
};
|
||||||
|
|
||||||
|
// PackedTypeConverter — map between CUDA scalar and packed types
|
||||||
|
// half <-> half2, __nv_bfloat16 <-> __nv_bfloat162, etc.
|
||||||
|
template <typename T>
|
||||||
|
struct PackedTypeConverter {
|
||||||
|
static_assert(sizeof(T) == 0,
|
||||||
|
"PackedTypeConverter is not specialized for this type.");
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct PackedTypeConverter<half2> {
|
||||||
|
using Type = half;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct PackedTypeConverter<half> {
|
||||||
|
using Type = half2;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct PackedTypeConverter<__nv_bfloat162> {
|
||||||
|
using Type = __nv_bfloat16;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct PackedTypeConverter<__nv_bfloat16> {
|
||||||
|
using Type = __nv_bfloat162;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct PackedTypeConverter<float> {
|
||||||
|
using Type = float2;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct PackedTypeConverter<float2> {
|
||||||
|
using Type = float;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct PackedTypeConverter<c10::Half> {
|
||||||
|
using Type = half2;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct PackedTypeConverter<c10::BFloat16> {
|
||||||
|
using Type = __nv_bfloat162;
|
||||||
|
};
|
||||||
|
|
||||||
|
// CUDATypeConverter — map PyTorch scalar types to CUDA scalar
|
||||||
|
// c10::Half -> half, c10::BFloat16 -> __nv_bfloat16
|
||||||
|
template <typename T>
|
||||||
|
struct CUDATypeConverter {
|
||||||
|
using Type = T;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct CUDATypeConverter<c10::Half> {
|
||||||
|
using Type = half;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct CUDATypeConverter<c10::BFloat16> {
|
||||||
|
using Type = __nv_bfloat16;
|
||||||
|
};
|
||||||
|
|
||||||
|
// PackedVec — typed vector container for packed element access.
|
||||||
|
// Derives alignment and element count from VecTraits.
|
||||||
|
// Type is the CUDA scalar type (e.g. half, __nv_bfloat16).
|
||||||
|
template <class Type, bool use_256b>
|
||||||
|
struct alignas(VecTraits<use_256b>::ARCH_MAX_VEC_SIZE) PackedVec {
|
||||||
|
static constexpr int NUM_ELTS =
|
||||||
|
VecTraits<use_256b>::ARCH_MAX_VEC_SIZE /
|
||||||
|
sizeof(typename PackedTypeConverter<Type>::Type);
|
||||||
|
typename PackedTypeConverter<Type>::Type elts[NUM_ELTS];
|
||||||
|
};
|
||||||
|
|
||||||
|
// ============================================================
|
||||||
|
// Load / store primitives
|
||||||
|
// ============================================================
|
||||||
|
|
||||||
|
// 256-bit load / store — SM100+ only (PTX v8 instructions).
|
||||||
|
__device__ __forceinline__ void ld256(u32x8_t& val, const u32x8_t* ptr) {
|
||||||
|
#if VLLM_256B_PTX_ENABLED
|
||||||
|
asm volatile("ld.global.nc.v8.u32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%8];\n"
|
||||||
|
: "=r"(val.d[0]), "=r"(val.d[1]), "=r"(val.d[2]), "=r"(val.d[3]),
|
||||||
|
"=r"(val.d[4]), "=r"(val.d[5]), "=r"(val.d[6]), "=r"(val.d[7])
|
||||||
|
: "l"(ptr));
|
||||||
|
#else
|
||||||
|
assert(false && "ld256 requires SM100+ with CUDA 12.9+");
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ __forceinline__ void st256(u32x8_t& val, u32x8_t* ptr) {
|
||||||
|
#if VLLM_256B_PTX_ENABLED
|
||||||
|
asm volatile("st.global.v8.u32 [%0], {%1,%2,%3,%4,%5,%6,%7,%8};\n"
|
||||||
|
:
|
||||||
|
: "l"(ptr), "r"(val.d[0]), "r"(val.d[1]), "r"(val.d[2]),
|
||||||
|
"r"(val.d[3]), "r"(val.d[4]), "r"(val.d[5]), "r"(val.d[6]),
|
||||||
|
"r"(val.d[7])
|
||||||
|
: "memory");
|
||||||
|
#else
|
||||||
|
assert(false && "st256 requires SM100+ with CUDA 12.9+");
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generic ld256 / st256 for any 32-byte aligned type (e.g. PackedVec).
|
||||||
|
// Non-template overloads above are preferred for u32x8_t.
|
||||||
|
template <typename T>
|
||||||
|
__device__ __forceinline__ void ld256(T& val, const T* ptr) {
|
||||||
|
static_assert(sizeof(T) == 32, "ld256 requires a 32-byte type");
|
||||||
|
ld256(reinterpret_cast<u32x8_t&>(val), reinterpret_cast<const u32x8_t*>(ptr));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__device__ __forceinline__ void st256(T& val, T* ptr) {
|
||||||
|
static_assert(sizeof(T) == 32, "st256 requires a 32-byte type");
|
||||||
|
st256(reinterpret_cast<u32x8_t&>(val), reinterpret_cast<u32x8_t*>(ptr));
|
||||||
|
}
|
||||||
|
|
||||||
|
// 128-bit load / store via __ldg (read-only cache hint).
|
||||||
|
template <typename T>
|
||||||
|
__device__ __forceinline__ void ld128(T& val, const T* ptr) {
|
||||||
|
static_assert(sizeof(T) == 16, "ld128 requires a 16-byte type");
|
||||||
|
*reinterpret_cast<int4*>(&val) = __ldg(reinterpret_cast<const int4*>(ptr));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__device__ __forceinline__ void st128(T& val, T* ptr) {
|
||||||
|
static_assert(sizeof(T) == 16, "st128 requires a 16-byte type");
|
||||||
|
*reinterpret_cast<int4*>(ptr) = *reinterpret_cast<int4*>(&val);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 256-bit cache-streaming (.cs) load / store — SM100+ only.
|
||||||
|
__forceinline__ __device__ u32x8_t ld256_cs(const u32x8_t* addr) {
|
||||||
|
#if VLLM_256B_PTX_ENABLED
|
||||||
|
u32x8_t val;
|
||||||
|
asm volatile("ld.global.cs.v8.u32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%8];"
|
||||||
|
: "=r"(val.d[0]), "=r"(val.d[1]), "=r"(val.d[2]), "=r"(val.d[3]),
|
||||||
|
"=r"(val.d[4]), "=r"(val.d[5]), "=r"(val.d[6]), "=r"(val.d[7])
|
||||||
|
: "l"(addr));
|
||||||
|
return val;
|
||||||
|
#else
|
||||||
|
assert(false && "ld256_cs requires SM100+ with CUDA 12.9+");
|
||||||
|
return {};
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline__ __device__ void st256_cs(u32x8_t* addr, u32x8_t val) {
|
||||||
|
#if VLLM_256B_PTX_ENABLED
|
||||||
|
asm volatile(
|
||||||
|
"st.global.cs.v8.u32 [%0], {%1,%2,%3,%4,%5,%6,%7,%8};" ::"l"(addr),
|
||||||
|
"r"(val.d[0]), "r"(val.d[1]), "r"(val.d[2]), "r"(val.d[3]), "r"(val.d[4]),
|
||||||
|
"r"(val.d[5]), "r"(val.d[6]), "r"(val.d[7]));
|
||||||
|
#else
|
||||||
|
assert(false && "st256_cs requires SM100+ with CUDA 12.9+");
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
// 32-bit cache-streaming (.cs) load / store — SM100+ only.
|
||||||
|
__forceinline__ __device__ int ld32_cs(const int* addr) {
|
||||||
|
#if VLLM_256B_PTX_ENABLED
|
||||||
|
int val;
|
||||||
|
asm volatile("ld.global.cs.b32 %0, [%1];" : "=r"(val) : "l"(addr));
|
||||||
|
return val;
|
||||||
|
#else
|
||||||
|
assert(false && "ld32_cs requires SM100+ with CUDA 12.9+");
|
||||||
|
return 0;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline__ __device__ void st32_cs(int* addr, int val) {
|
||||||
|
#if VLLM_256B_PTX_ENABLED
|
||||||
|
asm volatile("st.global.cs.b32 [%0], %1;" ::"l"(addr), "r"(val));
|
||||||
|
#else
|
||||||
|
assert(false && "st32_cs requires SM100+ with CUDA 12.9+");
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
// Predicated 256-bit / 128-bit cache-global (.cg) loads.
|
||||||
|
// Returns zero if pred is false. SM100+ only.
|
||||||
|
__device__ __forceinline__ void ld256_cg_or_zero(u32x8_t& val, const void* ptr,
|
||||||
|
bool pred) {
|
||||||
|
#if VLLM_256B_PTX_ENABLED
|
||||||
|
asm volatile(
|
||||||
|
"{\n"
|
||||||
|
" .reg .pred pr;\n"
|
||||||
|
" setp.ne.u32 pr, %8, 0;\n"
|
||||||
|
" mov.u32 %0, 0;\n"
|
||||||
|
" mov.u32 %1, 0;\n"
|
||||||
|
" mov.u32 %2, 0;\n"
|
||||||
|
" mov.u32 %3, 0;\n"
|
||||||
|
" mov.u32 %4, 0;\n"
|
||||||
|
" mov.u32 %5, 0;\n"
|
||||||
|
" mov.u32 %6, 0;\n"
|
||||||
|
" mov.u32 %7, 0;\n"
|
||||||
|
" @pr ld.global.cg.v8.u32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%9];\n"
|
||||||
|
"}\n"
|
||||||
|
: "=r"(val.d[0]), "=r"(val.d[1]), "=r"(val.d[2]), "=r"(val.d[3]),
|
||||||
|
"=r"(val.d[4]), "=r"(val.d[5]), "=r"(val.d[6]), "=r"(val.d[7])
|
||||||
|
: "r"((int)pred), "l"(ptr));
|
||||||
|
#else
|
||||||
|
assert(false && "ld256_cg_or_zero requires SM100+ with CUDA 12.9+");
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ __forceinline__ void ld128_cg_or_zero(uint4& val, const void* ptr,
|
||||||
|
bool pred) {
|
||||||
|
#if VLLM_256B_PTX_ENABLED
|
||||||
|
uint32_t r0, r1, r2, r3;
|
||||||
|
|
||||||
|
asm volatile(
|
||||||
|
"{\n"
|
||||||
|
" .reg .pred pr;\n"
|
||||||
|
" setp.ne.u32 pr, %4, 0;\n"
|
||||||
|
" mov.u32 %0, 0;\n"
|
||||||
|
" mov.u32 %1, 0;\n"
|
||||||
|
" mov.u32 %2, 0;\n"
|
||||||
|
" mov.u32 %3, 0;\n"
|
||||||
|
" @pr ld.global.cg.v4.u32 {%0,%1,%2,%3}, [%5];\n"
|
||||||
|
"}\n"
|
||||||
|
: "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3)
|
||||||
|
: "r"((int)pred), "l"(ptr));
|
||||||
|
|
||||||
|
val = uint4{r0, r1, r2, r3};
|
||||||
|
#else
|
||||||
|
assert(false && "ld128_cg_or_zero requires SM100+ with CUDA 12.9+");
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================
|
||||||
|
// Alignment helpers
|
||||||
|
// ============================================================
|
||||||
|
|
||||||
|
__host__ __device__ __forceinline__ bool is_16byte_aligned(const void* ptr) {
|
||||||
|
return (reinterpret_cast<uintptr_t>(ptr) & 15) == 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
__host__ __device__ __forceinline__ bool is_32byte_aligned(const void* ptr) {
|
||||||
|
return (reinterpret_cast<uintptr_t>(ptr) & 31) == 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================
|
||||||
|
// Packed type conversion and arithmetic
|
||||||
|
// ============================================================
|
||||||
|
|
||||||
|
template <typename packed_t>
|
||||||
|
__device__ __forceinline__ float2 cast_to_float2(const packed_t& val) {
|
||||||
|
if constexpr (std::is_same_v<packed_t, __nv_bfloat162>) {
|
||||||
|
return __bfloat1622float2(val);
|
||||||
|
} else if constexpr (std::is_same_v<packed_t, __half2>) {
|
||||||
|
return __half22float2(val);
|
||||||
|
} else if constexpr (std::is_same_v<packed_t, float2>) {
|
||||||
|
return float2(val);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename packed_t>
|
||||||
|
__device__ __forceinline__ packed_t cast_to_packed(const float2& val) {
|
||||||
|
if constexpr (std::is_same_v<packed_t, __nv_bfloat162>) {
|
||||||
|
return __float22bfloat162_rn(val);
|
||||||
|
} else if constexpr (std::is_same_v<packed_t, __half2>) {
|
||||||
|
return __float22half2_rn(val);
|
||||||
|
} else if constexpr (std::is_same_v<packed_t, float2>) {
|
||||||
|
return float2(val);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename packed_t>
|
||||||
|
__device__ __forceinline__ packed_t packed_mul(const packed_t& x,
|
||||||
|
const packed_t& y) {
|
||||||
|
if constexpr (std::is_same_v<packed_t, __nv_bfloat162> ||
|
||||||
|
std::is_same_v<packed_t, __half2>) {
|
||||||
|
return __hmul2(x, y);
|
||||||
|
} else if constexpr (std::is_same_v<packed_t, float2>) {
|
||||||
|
return make_float2(x.x * y.x, x.y * y.y);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace vllm
|
||||||
@@ -15,9 +15,9 @@
|
|||||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
struct SSMParamsBase {
|
struct SSMParamsBase {
|
||||||
using index_t = uint32_t;
|
using index_t = size_t;
|
||||||
|
|
||||||
int batch, dim, seqlen, dstate, n_groups, n_chunks;
|
int batch, dim, seqlen, dstate, n_groups;
|
||||||
int dim_ngroups_ratio;
|
int dim_ngroups_ratio;
|
||||||
bool is_variable_B;
|
bool is_variable_B;
|
||||||
bool is_variable_C;
|
bool is_variable_C;
|
||||||
@@ -72,6 +72,8 @@ struct SSMParamsBase {
|
|||||||
void *__restrict__ block_idx_first_scheduled_token_ptr; // (batch,) - first block to write
|
void *__restrict__ block_idx_first_scheduled_token_ptr; // (batch,) - first block to write
|
||||||
void *__restrict__ block_idx_last_scheduled_token_ptr; // (batch,) - last block to write
|
void *__restrict__ block_idx_last_scheduled_token_ptr; // (batch,) - last block to write
|
||||||
void *__restrict__ initial_state_idx_ptr; // (batch,) - index of the initial state to use
|
void *__restrict__ initial_state_idx_ptr; // (batch,) - index of the initial state to use
|
||||||
|
void *__restrict__ cu_chunk_seqlen_ptr; // (nchunks+1,) - cumulative chunk token offsets
|
||||||
|
void *__restrict__ last_chunk_indices_ptr; // (batch,) - index of last chunk per sequence
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -81,7 +81,6 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
|
|||||||
constexpr bool kIsVariableC = Ktraits::kIsVariableC;
|
constexpr bool kIsVariableC = Ktraits::kIsVariableC;
|
||||||
constexpr bool kHasZ = Ktraits::kHasZ;
|
constexpr bool kHasZ = Ktraits::kHasZ;
|
||||||
constexpr bool kVarlen = Ktraits::kVarlen;
|
constexpr bool kVarlen = Ktraits::kVarlen;
|
||||||
constexpr int kNThreads = Ktraits::kNThreads;
|
|
||||||
constexpr int kNItems = Ktraits::kNItems;
|
constexpr int kNItems = Ktraits::kNItems;
|
||||||
constexpr int kNRows = Ktraits::kNRows;
|
constexpr int kNRows = Ktraits::kNRows;
|
||||||
constexpr bool kDirectIO = Ktraits::kDirectIO;
|
constexpr bool kDirectIO = Ktraits::kDirectIO;
|
||||||
@@ -161,17 +160,8 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) {
|
|
||||||
// smem_a[state_idx] = A[state_idx * params.A_dstate_stride];
|
|
||||||
// smem_bc[state_idx] = B[state_idx * params.B_dstate_stride] * C[state_idx * params.C_dstate_stride];
|
|
||||||
// }
|
|
||||||
|
|
||||||
constexpr int kChunkSize = kNThreads * kNItems;
|
|
||||||
|
|
||||||
// Use block_size for chunking when APC is enabled, otherwise use 2048 for backwards compatibility
|
// Use block_size for chunking when APC is enabled, otherwise use 2048 for backwards compatibility
|
||||||
const int iteration_chunk_size = params.cache_enabled ? params.block_size : 2048;
|
const int block_size = params.cache_enabled ? params.block_size : 2048;
|
||||||
const int n_chunks = (seqlen + iteration_chunk_size - 1) / iteration_chunk_size;
|
|
||||||
|
|
||||||
const int* batch_cache_indices = cache_indices != nullptr ?
|
const int* batch_cache_indices = cache_indices != nullptr ?
|
||||||
cache_indices + batch_id * params.cache_indices_stride : nullptr;
|
cache_indices + batch_id * params.cache_indices_stride : nullptr;
|
||||||
@@ -181,10 +171,44 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
|
|||||||
reinterpret_cast<const int*>(params.block_idx_last_scheduled_token_ptr) : nullptr;
|
reinterpret_cast<const int*>(params.block_idx_last_scheduled_token_ptr) : nullptr;
|
||||||
const int* initial_state_idx = params.initial_state_idx_ptr != nullptr ?
|
const int* initial_state_idx = params.initial_state_idx_ptr != nullptr ?
|
||||||
reinterpret_cast<const int*>(params.initial_state_idx_ptr) : nullptr;
|
reinterpret_cast<const int*>(params.initial_state_idx_ptr) : nullptr;
|
||||||
|
const int* cu_chunk_seqlen = params.cu_chunk_seqlen_ptr != nullptr ?
|
||||||
|
reinterpret_cast<const int*>(params.cu_chunk_seqlen_ptr) : nullptr;
|
||||||
|
const int* last_chunk_indices = params.last_chunk_indices_ptr != nullptr ?
|
||||||
|
reinterpret_cast<const int*>(params.last_chunk_indices_ptr) : nullptr;
|
||||||
|
|
||||||
const size_t load_cache_slot = params.cache_enabled && batch_cache_indices != nullptr ? batch_cache_indices[initial_state_idx[batch_id]] : cache_index;
|
const size_t load_cache_slot = params.cache_enabled && batch_cache_indices != nullptr ? batch_cache_indices[initial_state_idx[batch_id]] : cache_index;
|
||||||
|
|
||||||
|
const int block_idx_first = (params.cache_enabled && block_idx_first_scheduled != nullptr) ?
|
||||||
|
block_idx_first_scheduled[batch_id] : 0;
|
||||||
|
|
||||||
|
// Determine chunk boundaries from pre-computed metadata (APC mode)
|
||||||
|
// or fall back to simple block_size chunking.
|
||||||
|
int first_chunk_idx, n_chunks;
|
||||||
|
int current_position;
|
||||||
|
|
||||||
|
if (cu_chunk_seqlen != nullptr && last_chunk_indices != nullptr) {
|
||||||
|
const int last_chunk_idx = last_chunk_indices[batch_id];
|
||||||
|
first_chunk_idx = (batch_id == 0) ? 0 : last_chunk_indices[batch_id - 1] + 1;
|
||||||
|
n_chunks = last_chunk_idx - first_chunk_idx + 1;
|
||||||
|
// Derive current_position: if the first chunk is partial (fills remainder
|
||||||
|
// of a started block), offset into the block accordingly.
|
||||||
|
const int first_chunk_tokens = cu_chunk_seqlen[first_chunk_idx + 1] - cu_chunk_seqlen[first_chunk_idx];
|
||||||
|
const int chunk_start_offset = (n_chunks > 1 && first_chunk_tokens < block_size)
|
||||||
|
? (block_size - first_chunk_tokens) : 0;
|
||||||
|
current_position = block_idx_first * block_size + chunk_start_offset;
|
||||||
|
} else {
|
||||||
|
first_chunk_idx = 0;
|
||||||
|
n_chunks = (seqlen + block_size - 1) / block_size;
|
||||||
|
current_position = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
int tokens_processed = 0;
|
||||||
|
|
||||||
for (int chunk = 0; chunk < n_chunks; ++chunk) {
|
for (int chunk = 0; chunk < n_chunks; ++chunk) {
|
||||||
|
const int chunk_tokens = (cu_chunk_seqlen != nullptr)
|
||||||
|
? cu_chunk_seqlen[first_chunk_idx + chunk + 1] - cu_chunk_seqlen[first_chunk_idx + chunk]
|
||||||
|
: min(block_size, seqlen - tokens_processed);
|
||||||
|
if (chunk_tokens <= 0) break;
|
||||||
input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems];
|
input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems];
|
||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
@@ -193,12 +217,12 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
|
|||||||
if constexpr (!kDirectIO) {
|
if constexpr (!kDirectIO) {
|
||||||
if (r > 0) { __syncthreads(); }
|
if (r > 0) { __syncthreads(); }
|
||||||
}
|
}
|
||||||
load_input<Ktraits>(u + r * params.u_d_stride, u_vals[r], smem_load, seqlen - chunk * kChunkSize);
|
load_input<Ktraits>(u + r * params.u_d_stride, u_vals[r], smem_load, chunk_tokens);
|
||||||
if constexpr (!kDirectIO) { __syncthreads(); }
|
if constexpr (!kDirectIO) { __syncthreads(); }
|
||||||
load_input<Ktraits>(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, seqlen - chunk * kChunkSize);
|
load_input<Ktraits>(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, chunk_tokens);
|
||||||
}
|
}
|
||||||
u += kChunkSize;
|
u += chunk_tokens;
|
||||||
delta += kChunkSize;
|
delta += chunk_tokens;
|
||||||
|
|
||||||
float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems];
|
float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems];
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
@@ -232,7 +256,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
|
|||||||
weight_t B_vals[kNItems], C_vals[kNItems];
|
weight_t B_vals[kNItems], C_vals[kNItems];
|
||||||
if constexpr (kIsVariableB) {
|
if constexpr (kIsVariableB) {
|
||||||
load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
|
load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
|
||||||
smem_load_weight, (seqlen - chunk * kChunkSize) * (1));
|
smem_load_weight, chunk_tokens);
|
||||||
if constexpr (!kIsVariableC) {
|
if constexpr (!kIsVariableC) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int r = 0; r < kNRows; ++r) {
|
for (int r = 0; r < kNRows; ++r) {
|
||||||
@@ -243,7 +267,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
|
|||||||
if constexpr (kIsVariableC) {
|
if constexpr (kIsVariableC) {
|
||||||
auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1;
|
auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1;
|
||||||
load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
|
load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
|
||||||
smem_load_weight_C, (seqlen - chunk * kChunkSize) * (1));
|
smem_load_weight_C, chunk_tokens);
|
||||||
if constexpr (!kIsVariableB) {
|
if constexpr (!kIsVariableB) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int r = 0; r < kNRows; ++r) {
|
for (int r = 0; r < kNRows; ++r) {
|
||||||
@@ -266,10 +290,8 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
|
|||||||
for (int i = 0; i < kNItems; ++i) {
|
for (int i = 0; i < kNItems; ++i) {
|
||||||
thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]),
|
thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]),
|
||||||
!kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]);
|
!kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]);
|
||||||
if (seqlen % (kNItems * kNThreads) != 0) { // So that the last state is correct
|
if (threadIdx.x * kNItems + i >= chunk_tokens) {
|
||||||
if (threadIdx.x * kNItems + i >= seqlen - chunk * kChunkSize) {
|
thread_data[i] = make_float2(1.f, 0.f);
|
||||||
thread_data[i] = make_float2(1.f, 0.f);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Initialize running total
|
// Initialize running total
|
||||||
@@ -301,14 +323,14 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
|
|||||||
if (threadIdx.x == 0) {
|
if (threadIdx.x == 0) {
|
||||||
smem_running_prefix[state_idx + r * MAX_DSTATE] = prefix_op.running_prefix;
|
smem_running_prefix[state_idx + r * MAX_DSTATE] = prefix_op.running_prefix;
|
||||||
|
|
||||||
// Store state at the end of each chunk when cache is enabled
|
// Store state at the end of each aligned chunk when cache is enabled
|
||||||
if (params.cache_enabled && batch_cache_indices != nullptr) {
|
if (params.cache_enabled && batch_cache_indices != nullptr) {
|
||||||
|
|
||||||
size_t cache_slot;
|
size_t cache_slot;
|
||||||
if (chunk == n_chunks - 1) {
|
if (chunk == n_chunks - 1) {
|
||||||
cache_slot = batch_cache_indices[block_idx_last_scheduled[batch_id]];
|
cache_slot = batch_cache_indices[block_idx_last_scheduled[batch_id]];
|
||||||
} else {
|
} else {
|
||||||
cache_slot = batch_cache_indices[block_idx_first_scheduled[batch_id] + chunk];
|
const int block_idx_completed = (current_position + chunk_tokens - 1) / block_size;
|
||||||
|
cache_slot = batch_cache_indices[block_idx_completed];
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t state_offset = cache_slot * params.ssm_states_batch_stride +
|
size_t state_offset = cache_slot * params.ssm_states_batch_stride +
|
||||||
@@ -331,38 +353,41 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + sequence_start_index * params.out_batch_stride
|
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + sequence_start_index * params.out_batch_stride
|
||||||
+ dim_id * kNRows * params.out_d_stride + chunk * kChunkSize;
|
+ dim_id * kNRows * params.out_d_stride + tokens_processed;
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int r = 0; r < kNRows; ++r) {
|
for (int r = 0; r < kNRows; ++r) {
|
||||||
if constexpr (!kDirectIO) {
|
if constexpr (!kDirectIO) {
|
||||||
if (r > 0) { __syncthreads(); }
|
if (r > 0) { __syncthreads(); }
|
||||||
}
|
}
|
||||||
store_output<Ktraits>(out + r * params.out_d_stride, out_vals[r], smem_store, seqlen - chunk * kChunkSize);
|
store_output<Ktraits>(out + r * params.out_d_stride, out_vals[r], smem_store, chunk_tokens);
|
||||||
}
|
}
|
||||||
|
|
||||||
if constexpr (kHasZ) {
|
if constexpr (kHasZ) {
|
||||||
input_t *z = reinterpret_cast<input_t *>(params.z_ptr) + sequence_start_index * params.z_batch_stride
|
input_t *z = reinterpret_cast<input_t *>(params.z_ptr) + sequence_start_index * params.z_batch_stride
|
||||||
+ dim_id * kNRows * params.z_d_stride + chunk * kChunkSize;
|
+ dim_id * kNRows * params.z_d_stride + tokens_processed;
|
||||||
input_t *out_z = reinterpret_cast<input_t *>(params.out_z_ptr) + sequence_start_index * params.out_z_batch_stride
|
input_t *out_z = reinterpret_cast<input_t *>(params.out_z_ptr) + sequence_start_index * params.out_z_batch_stride
|
||||||
+ dim_id * kNRows * params.out_z_d_stride + chunk * kChunkSize;
|
+ dim_id * kNRows * params.out_z_d_stride + tokens_processed;
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int r = 0; r < kNRows; ++r) {
|
for (int r = 0; r < kNRows; ++r) {
|
||||||
input_t z_vals[kNItems];
|
input_t z_vals[kNItems];
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
load_input<Ktraits>(z + r * params.z_d_stride, z_vals, smem_load, seqlen - chunk * kChunkSize);
|
load_input<Ktraits>(z + r * params.z_d_stride, z_vals, smem_load, chunk_tokens);
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < kNItems; ++i) {
|
for (int i = 0; i < kNItems; ++i) {
|
||||||
float z_val = z_vals[i];
|
float z_val = z_vals[i];
|
||||||
out_vals[r][i] *= z_val / (1 + expf(-z_val));
|
out_vals[r][i] *= z_val / (1 + expf(-z_val));
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
store_output<Ktraits>(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, seqlen - chunk * kChunkSize);
|
store_output<Ktraits>(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, chunk_tokens);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Bvar += kChunkSize * 1;
|
Bvar += chunk_tokens;
|
||||||
Cvar += kChunkSize * 1;
|
Cvar += chunk_tokens;
|
||||||
|
|
||||||
|
tokens_processed += chunk_tokens;
|
||||||
|
current_position += chunk_tokens;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -506,7 +531,9 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms,
|
|||||||
int64_t block_size,
|
int64_t block_size,
|
||||||
const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
|
const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
|
||||||
const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
|
const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
|
||||||
const std::optional<torch::Tensor> &initial_state_idx) {
|
const std::optional<torch::Tensor> &initial_state_idx,
|
||||||
|
const std::optional<torch::Tensor> &cu_chunk_seqlen,
|
||||||
|
const std::optional<torch::Tensor> &last_chunk_indices) {
|
||||||
|
|
||||||
// Reset the parameters
|
// Reset the parameters
|
||||||
memset(¶ms, 0, sizeof(params));
|
memset(¶ms, 0, sizeof(params));
|
||||||
@@ -548,6 +575,8 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms,
|
|||||||
params.block_idx_first_scheduled_token_ptr = block_idx_first_scheduled_token.has_value() ? block_idx_first_scheduled_token.value().data_ptr() : nullptr;
|
params.block_idx_first_scheduled_token_ptr = block_idx_first_scheduled_token.has_value() ? block_idx_first_scheduled_token.value().data_ptr() : nullptr;
|
||||||
params.block_idx_last_scheduled_token_ptr = block_idx_last_scheduled_token.has_value() ? block_idx_last_scheduled_token.value().data_ptr() : nullptr;
|
params.block_idx_last_scheduled_token_ptr = block_idx_last_scheduled_token.has_value() ? block_idx_last_scheduled_token.value().data_ptr() : nullptr;
|
||||||
params.initial_state_idx_ptr = initial_state_idx.has_value() ? initial_state_idx.value().data_ptr() : nullptr;
|
params.initial_state_idx_ptr = initial_state_idx.has_value() ? initial_state_idx.value().data_ptr() : nullptr;
|
||||||
|
params.cu_chunk_seqlen_ptr = cu_chunk_seqlen.has_value() ? cu_chunk_seqlen.value().data_ptr() : nullptr;
|
||||||
|
params.last_chunk_indices_ptr = last_chunk_indices.has_value() ? last_chunk_indices.value().data_ptr() : nullptr;
|
||||||
|
|
||||||
// All stride are in elements, not bytes.
|
// All stride are in elements, not bytes.
|
||||||
params.A_d_stride = A.stride(0);
|
params.A_d_stride = A.stride(0);
|
||||||
@@ -633,7 +662,9 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
|
|||||||
int64_t block_size,
|
int64_t block_size,
|
||||||
const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
|
const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
|
||||||
const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
|
const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
|
||||||
const std::optional<torch::Tensor> &initial_state_idx) {
|
const std::optional<torch::Tensor> &initial_state_idx,
|
||||||
|
const std::optional<torch::Tensor> &cu_chunk_seqlen,
|
||||||
|
const std::optional<torch::Tensor> &last_chunk_indices) {
|
||||||
auto input_type = u.scalar_type();
|
auto input_type = u.scalar_type();
|
||||||
auto weight_type = A.scalar_type();
|
auto weight_type = A.scalar_type();
|
||||||
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
||||||
@@ -778,7 +809,9 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
|
|||||||
block_size,
|
block_size,
|
||||||
block_idx_first_scheduled_token,
|
block_idx_first_scheduled_token,
|
||||||
block_idx_last_scheduled_token,
|
block_idx_last_scheduled_token,
|
||||||
initial_state_idx
|
initial_state_idx,
|
||||||
|
cu_chunk_seqlen,
|
||||||
|
last_chunk_indices
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -58,6 +58,10 @@ void shuffle_rows(const torch::Tensor& input_tensor,
|
|||||||
torch::Tensor& output_tensor);
|
torch::Tensor& output_tensor);
|
||||||
|
|
||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
|
// cuBLAS bf16 x bf16 -> fp32 router GEMM (fallback for non-SM90 / batch > 16)
|
||||||
|
torch::Tensor router_gemm_bf16_fp32(torch::Tensor const& input,
|
||||||
|
torch::Tensor const& weight);
|
||||||
|
|
||||||
// DeepSeek V3 optimized router GEMM kernel for SM90+
|
// DeepSeek V3 optimized router GEMM kernel for SM90+
|
||||||
// Computes output = mat_a @ mat_b.T where:
|
// Computes output = mat_a @ mat_b.T where:
|
||||||
// mat_a: [num_tokens, hidden_dim] in bf16
|
// mat_a: [num_tokens, hidden_dim] in bf16
|
||||||
|
|||||||
60
csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm.cu
Normal file
60
csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm.cu
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
// Adapted from SGLang:
|
||||||
|
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled.cu
|
||||||
|
|
||||||
|
#include <torch/all.h>
|
||||||
|
|
||||||
|
#include "cutlass_mxfp8_grouped_mm_launcher.cuh"
|
||||||
|
|
||||||
|
void cutlass_mxfp8_grouped_mm(const torch::Tensor& a, const torch::Tensor& b,
|
||||||
|
const torch::Tensor& sfa,
|
||||||
|
const torch::Tensor& sfb, torch::Tensor& d,
|
||||||
|
const torch::Tensor& problem_sizes,
|
||||||
|
const torch::Tensor& expert_offsets,
|
||||||
|
const torch::Tensor& blockscale_offsets) {
|
||||||
|
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||||
|
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
|
||||||
|
TORCH_CHECK(problem_sizes.size(1) == 3,
|
||||||
|
"problem_sizes must have shape (num_experts, 3)");
|
||||||
|
TORCH_CHECK(problem_sizes.size(0) == expert_offsets.size(0),
|
||||||
|
"Number of experts in problem_sizes must match expert_offsets");
|
||||||
|
TORCH_CHECK(problem_sizes.dtype() == torch::kInt32,
|
||||||
|
"problem_sizes must be int32");
|
||||||
|
TORCH_CHECK(expert_offsets.dtype() == torch::kInt32,
|
||||||
|
"expert_offsets must be int32");
|
||||||
|
TORCH_CHECK(blockscale_offsets.dtype() == torch::kInt32,
|
||||||
|
"blockscale_offsets must be int32");
|
||||||
|
TORCH_CHECK(a.dim() == 2, "a must be a 2D tensor of shape (num_tokens, k)");
|
||||||
|
TORCH_CHECK(b.dim() == 3,
|
||||||
|
"b must be a 3D tensor of shape (num_experts, k, n)");
|
||||||
|
TORCH_CHECK(a.size(1) == b.size(1) && a.size(1) % 128 == 0,
|
||||||
|
"k should align 128");
|
||||||
|
TORCH_CHECK(b.size(2) % 128 == 0, "n should align 128");
|
||||||
|
TORCH_CHECK(a.strides()[1] == 1, "a must be row major");
|
||||||
|
TORCH_CHECK(b.strides()[1] == 1, "b must be column major");
|
||||||
|
|
||||||
|
auto stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
if (d.dtype() == torch::kBFloat16) {
|
||||||
|
expert_specialization::cutlass_mxfp8_grouped_mm_dispatch_out_dtype<
|
||||||
|
cutlass::bfloat16_t>(a, b, sfa, sfb, d, problem_sizes, expert_offsets,
|
||||||
|
blockscale_offsets, stream);
|
||||||
|
} else if (d.dtype() == torch::kFloat16) {
|
||||||
|
expert_specialization::cutlass_mxfp8_grouped_mm_dispatch_out_dtype<
|
||||||
|
cutlass::half_t>(a, b, sfa, sfb, d, problem_sizes, expert_offsets,
|
||||||
|
blockscale_offsets, stream);
|
||||||
|
} else {
|
||||||
|
TORCH_CHECK(false, "dtype must be kFloat16 or kBFloat16");
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
TORCH_CHECK(false,
|
||||||
|
"No implemented cutlass_mxfp8_grouped_mm for "
|
||||||
|
"current device");
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
#include "core/registration.h"
|
||||||
|
|
||||||
|
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||||
|
m.impl("cutlass_mxfp8_grouped_mm", cutlass_mxfp8_grouped_mm);
|
||||||
|
}
|
||||||
141
csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_functor.cuh
Normal file
141
csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_functor.cuh
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
// Adapted from SGLang:
|
||||||
|
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_functor.cuh
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
#include <cuda.h>
|
||||||
|
|
||||||
|
#include "cute/tensor.hpp"
|
||||||
|
#include "cutlass/util/packed_stride.hpp"
|
||||||
|
#include "cutlass_mxfp8_grouped_mm_traits.cuh"
|
||||||
|
|
||||||
|
namespace expert_specialization {
|
||||||
|
|
||||||
|
using namespace cute;
|
||||||
|
|
||||||
|
template <typename GemmTraits>
|
||||||
|
struct CutlassMxfp8GroupedMmOffsetFunctor {
|
||||||
|
using Gemm = typename GemmTraits::Gemm;
|
||||||
|
using ElementA = typename Gemm::ElementA;
|
||||||
|
using ElementB = typename Gemm::ElementB;
|
||||||
|
using ElementSF = typename GemmTraits::ElementSF;
|
||||||
|
using ElementD = typename GemmTraits::ElementOutput;
|
||||||
|
// Input
|
||||||
|
int* expert_offsets{nullptr};
|
||||||
|
int* blockscale_offsets{nullptr};
|
||||||
|
// Output
|
||||||
|
ElementA* a_base{nullptr};
|
||||||
|
ElementB* b_base{nullptr};
|
||||||
|
ElementSF* sfa_base{nullptr};
|
||||||
|
ElementSF* sfb_base{nullptr};
|
||||||
|
ElementD* d_base{nullptr};
|
||||||
|
ElementA** a_offsets{nullptr};
|
||||||
|
ElementB** b_offsets{nullptr};
|
||||||
|
ElementSF** sfa_offsets{nullptr};
|
||||||
|
ElementSF** sfb_offsets{nullptr};
|
||||||
|
ElementD** d_offsets{nullptr};
|
||||||
|
|
||||||
|
CutlassMxfp8GroupedMmOffsetFunctor() = default;
|
||||||
|
CutlassMxfp8GroupedMmOffsetFunctor(
|
||||||
|
int* _expert_offsets, int* _blockscale_offsets, ElementA* _a_base,
|
||||||
|
ElementB* _b_base, ElementSF* _sfa_base, ElementSF* _sfb_base,
|
||||||
|
ElementD* _d_base, ElementA** _a_offsets, ElementB** _b_offsets,
|
||||||
|
ElementSF** _sfa_offsets, ElementSF** _sfb_offsets, ElementD** _d_offsets)
|
||||||
|
: expert_offsets{_expert_offsets},
|
||||||
|
blockscale_offsets{_blockscale_offsets},
|
||||||
|
a_base(_a_base),
|
||||||
|
b_base(_b_base),
|
||||||
|
sfa_base(_sfa_base),
|
||||||
|
sfb_base(_sfb_base),
|
||||||
|
d_base(_d_base),
|
||||||
|
a_offsets(_a_offsets),
|
||||||
|
b_offsets(_b_offsets),
|
||||||
|
sfa_offsets(_sfa_offsets),
|
||||||
|
sfb_offsets(_sfb_offsets),
|
||||||
|
d_offsets(_d_offsets) {}
|
||||||
|
|
||||||
|
void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) {
|
||||||
|
int64_t expert_offset = static_cast<int64_t>(expert_offsets[expert_id]);
|
||||||
|
int64_t blockscale_offset =
|
||||||
|
static_cast<int64_t>(blockscale_offsets[expert_id]);
|
||||||
|
int64_t a_stride = expert_offset * k;
|
||||||
|
int64_t b_stride = expert_id * k * n;
|
||||||
|
int64_t d_stride = expert_offset * n;
|
||||||
|
int64_t sfa_stride = blockscale_offset * (k / 32);
|
||||||
|
int64_t sfb_stride = expert_id * n * (k / 32);
|
||||||
|
|
||||||
|
a_offsets[expert_id] = a_base + a_stride;
|
||||||
|
b_offsets[expert_id] = b_base + b_stride;
|
||||||
|
sfa_offsets[expert_id] = sfa_base + sfa_stride;
|
||||||
|
sfb_offsets[expert_id] = sfb_base + sfb_stride;
|
||||||
|
d_offsets[expert_id] = d_base + d_stride;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename GemmTraits>
|
||||||
|
struct CutlassMxfp8GroupedMmLayoutFunctor {
|
||||||
|
using Sm1xxBlkScaledConfig = typename GemmTraits::Sm1xxBlkScaledConfig;
|
||||||
|
using LayoutSFA = typename GemmTraits::LayoutSFA;
|
||||||
|
using LayoutSFB = typename GemmTraits::LayoutSFB;
|
||||||
|
LayoutSFA* layout_sfa_base{nullptr};
|
||||||
|
LayoutSFB* layout_sfb_base{nullptr};
|
||||||
|
|
||||||
|
CutlassMxfp8GroupedMmLayoutFunctor() = default;
|
||||||
|
CutlassMxfp8GroupedMmLayoutFunctor(LayoutSFA* _layout_sfa_base,
|
||||||
|
LayoutSFB* _layout_sfb_base)
|
||||||
|
: layout_sfa_base(_layout_sfa_base), layout_sfb_base(_layout_sfb_base) {}
|
||||||
|
|
||||||
|
void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) {
|
||||||
|
LayoutSFA* layout_sfa_ptr = layout_sfa_base + expert_id;
|
||||||
|
LayoutSFB* layout_sfb_ptr = layout_sfb_base + expert_id;
|
||||||
|
*layout_sfa_ptr = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(
|
||||||
|
cute::make_shape(m, n, k, 1));
|
||||||
|
*layout_sfb_ptr = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(
|
||||||
|
cute::make_shape(m, n, k, 1));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename GemmTraits>
|
||||||
|
struct CutlassMxfp8GroupedMmStrideFunctor {
|
||||||
|
using StrideA = typename GemmTraits::StrideA;
|
||||||
|
using StrideB = typename GemmTraits::StrideB;
|
||||||
|
using StrideD = typename GemmTraits::StrideD;
|
||||||
|
StrideA* stride_A_base{nullptr};
|
||||||
|
StrideB* stride_B_base{nullptr};
|
||||||
|
StrideD* stride_D_base{nullptr};
|
||||||
|
|
||||||
|
CutlassMxfp8GroupedMmStrideFunctor() = default;
|
||||||
|
CutlassMxfp8GroupedMmStrideFunctor(StrideA* _stride_A_base,
|
||||||
|
StrideB* _stride_B_base,
|
||||||
|
StrideD* _stride_D_base)
|
||||||
|
: stride_A_base(_stride_A_base),
|
||||||
|
stride_B_base(_stride_B_base),
|
||||||
|
stride_D_base(_stride_D_base) {}
|
||||||
|
|
||||||
|
void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) {
|
||||||
|
StrideA* stride_A = stride_A_base + expert_id;
|
||||||
|
StrideB* stride_B = stride_B_base + expert_id;
|
||||||
|
StrideD* stride_D = stride_D_base + expert_id;
|
||||||
|
*stride_A = cutlass::make_cute_packed_stride(StrideA{}, {m, k, 1});
|
||||||
|
*stride_B = cutlass::make_cute_packed_stride(StrideB{}, {n, k, 1});
|
||||||
|
*stride_D = cutlass::make_cute_packed_stride(StrideD{}, {m, n, 1});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename OffsetFunctor, typename LayoutFunctor,
|
||||||
|
typename StrideFunctor>
|
||||||
|
__global__ void cutlassMxfp8GroupedMmPreComputeKernel(
|
||||||
|
int* problem_sizes, OffsetFunctor offset_functor,
|
||||||
|
LayoutFunctor layout_functor, StrideFunctor stride_functor) {
|
||||||
|
int64_t expert_id = static_cast<int64_t>(threadIdx.x);
|
||||||
|
int m = problem_sizes[expert_id * 3 + 0];
|
||||||
|
int n = problem_sizes[expert_id * 3 + 1];
|
||||||
|
int k = problem_sizes[expert_id * 3 + 2];
|
||||||
|
|
||||||
|
offset_functor(expert_id, m, n, k);
|
||||||
|
layout_functor(expert_id, m, n, k);
|
||||||
|
stride_functor(expert_id, m, n, k);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace expert_specialization
|
||||||
179
csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_launcher.cuh
Normal file
179
csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_launcher.cuh
Normal file
@@ -0,0 +1,179 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
// Adapted from SGLang:
|
||||||
|
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_launcher.cuh
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
#include <torch/all.h>
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
|
#include <iostream>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "cute/tensor.hpp"
|
||||||
|
#include "cutlass_mxfp8_grouped_mm_functor.cuh"
|
||||||
|
#include "cutlass_mxfp8_grouped_mm_traits.cuh"
|
||||||
|
|
||||||
|
namespace expert_specialization {
|
||||||
|
|
||||||
|
template <typename GemmTraits>
|
||||||
|
void cutlass_mxfp8_grouped_mm_pre_compute(
|
||||||
|
torch::Tensor& a_ptrs, torch::Tensor& b_ptrs, torch::Tensor& sfa_ptrs,
|
||||||
|
torch::Tensor& sfb_ptrs, torch::Tensor& d_ptrs, torch::Tensor& stride_a,
|
||||||
|
torch::Tensor& stride_b, torch::Tensor& stride_d, torch::Tensor& layout_sfa,
|
||||||
|
torch::Tensor& layout_sfb, const torch::Tensor& a, const torch::Tensor& b,
|
||||||
|
const torch::Tensor& sfa, const torch::Tensor& sfb, const torch::Tensor& d,
|
||||||
|
const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets,
|
||||||
|
const torch::Tensor& blockscale_offsets, cudaStream_t stream) {
|
||||||
|
using OffsetFunctor = CutlassMxfp8GroupedMmOffsetFunctor<GemmTraits>;
|
||||||
|
using ElementA = typename OffsetFunctor::ElementA;
|
||||||
|
using ElementB = typename OffsetFunctor::ElementB;
|
||||||
|
using ElementSF = typename OffsetFunctor::ElementSF;
|
||||||
|
using ElementD = typename OffsetFunctor::ElementD;
|
||||||
|
|
||||||
|
using LayoutFunctor = CutlassMxfp8GroupedMmLayoutFunctor<GemmTraits>;
|
||||||
|
using LayoutSFA = typename LayoutFunctor::LayoutSFA;
|
||||||
|
using LayoutSFB = typename LayoutFunctor::LayoutSFB;
|
||||||
|
|
||||||
|
using StrideFunctor = CutlassMxfp8GroupedMmStrideFunctor<GemmTraits>;
|
||||||
|
using StrideA = typename StrideFunctor::StrideA;
|
||||||
|
using StrideB = typename StrideFunctor::StrideB;
|
||||||
|
using StrideD = typename StrideFunctor::StrideD;
|
||||||
|
|
||||||
|
int num_experts = (int)expert_offsets.size(0);
|
||||||
|
TORCH_CHECK(num_experts <= 1024,
|
||||||
|
"Number of experts cannot exceed 1024, the maximum number of "
|
||||||
|
"threads per block.");
|
||||||
|
|
||||||
|
OffsetFunctor offset_functor(
|
||||||
|
reinterpret_cast<int*>(expert_offsets.data_ptr()),
|
||||||
|
reinterpret_cast<int*>(blockscale_offsets.data_ptr()),
|
||||||
|
reinterpret_cast<ElementA*>(a.data_ptr()),
|
||||||
|
reinterpret_cast<ElementB*>(b.data_ptr()),
|
||||||
|
reinterpret_cast<ElementSF*>(sfa.data_ptr()),
|
||||||
|
reinterpret_cast<ElementSF*>(sfb.data_ptr()),
|
||||||
|
reinterpret_cast<ElementD*>(d.data_ptr()),
|
||||||
|
reinterpret_cast<ElementA**>(a_ptrs.data_ptr()),
|
||||||
|
reinterpret_cast<ElementB**>(b_ptrs.data_ptr()),
|
||||||
|
reinterpret_cast<ElementSF**>(sfa_ptrs.data_ptr()),
|
||||||
|
reinterpret_cast<ElementSF**>(sfb_ptrs.data_ptr()),
|
||||||
|
reinterpret_cast<ElementD**>(d_ptrs.data_ptr()));
|
||||||
|
LayoutFunctor layout_functor(
|
||||||
|
reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()),
|
||||||
|
reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr()));
|
||||||
|
StrideFunctor stride_functor(reinterpret_cast<StrideA*>(stride_a.data_ptr()),
|
||||||
|
reinterpret_cast<StrideB*>(stride_b.data_ptr()),
|
||||||
|
reinterpret_cast<StrideD*>(stride_d.data_ptr()));
|
||||||
|
cutlassMxfp8GroupedMmPreComputeKernel<<<1, num_experts, 0, stream>>>(
|
||||||
|
static_cast<int*>(problem_sizes.data_ptr()), offset_functor,
|
||||||
|
layout_functor, stride_functor);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename GemmTraits>
|
||||||
|
void cutlass_mxfp8_grouped_mm(
|
||||||
|
const torch::Tensor& a_ptrs, const torch::Tensor& b_ptrs,
|
||||||
|
const torch::Tensor& sfa_ptrs, const torch::Tensor& sfb_ptrs,
|
||||||
|
const torch::Tensor& d_ptrs, const torch::Tensor& stride_a,
|
||||||
|
const torch::Tensor& stride_b, const torch::Tensor& stride_d,
|
||||||
|
const torch::Tensor& layout_sfa, const torch::Tensor& layout_sfb,
|
||||||
|
const torch::Tensor& problem_sizes, cudaStream_t stream) {
|
||||||
|
using Gemm = typename GemmTraits::Gemm;
|
||||||
|
using ElementA = typename Gemm::ElementA;
|
||||||
|
using ElementB = typename Gemm::ElementB;
|
||||||
|
using ElementSF = typename GemmTraits::ElementSF;
|
||||||
|
using ElementD = typename GemmTraits::ElementOutput;
|
||||||
|
using StrideA = typename GemmTraits::StrideA;
|
||||||
|
using StrideB = typename GemmTraits::StrideB;
|
||||||
|
using StrideD = typename GemmTraits::StrideD;
|
||||||
|
using LayoutSFA = typename GemmTraits::LayoutSFA;
|
||||||
|
using LayoutSFB = typename GemmTraits::LayoutSFB;
|
||||||
|
using UnderlyingProblemShape =
|
||||||
|
typename GemmTraits::ProblemShape::UnderlyingProblemShape;
|
||||||
|
|
||||||
|
cutlass::KernelHardwareInfo hw_info;
|
||||||
|
hw_info.device_id = c10::cuda::current_device();
|
||||||
|
hw_info.sm_count =
|
||||||
|
at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
|
||||||
|
hw_info.cluster_shape = GemmTraits::MMAConfig::preferred_cluster;
|
||||||
|
hw_info.cluster_shape_fallback = GemmTraits::MMAConfig::fallback_cluster;
|
||||||
|
|
||||||
|
int num_experts = (int)problem_sizes.size(0);
|
||||||
|
|
||||||
|
UnderlyingProblemShape* underlying_problem_shape =
|
||||||
|
reinterpret_cast<UnderlyingProblemShape*>(problem_sizes.data_ptr());
|
||||||
|
|
||||||
|
typename Gemm::Arguments arguments = {
|
||||||
|
cutlass::gemm::GemmUniversalMode::kGrouped,
|
||||||
|
{num_experts, underlying_problem_shape, nullptr},
|
||||||
|
{reinterpret_cast<const ElementA**>(a_ptrs.data_ptr()),
|
||||||
|
reinterpret_cast<StrideA*>(stride_a.data_ptr()),
|
||||||
|
reinterpret_cast<const ElementB**>(b_ptrs.data_ptr()),
|
||||||
|
reinterpret_cast<StrideB*>(stride_b.data_ptr()),
|
||||||
|
reinterpret_cast<const ElementSF**>(sfa_ptrs.data_ptr()),
|
||||||
|
reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()),
|
||||||
|
reinterpret_cast<const ElementSF**>(sfb_ptrs.data_ptr()),
|
||||||
|
reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr())},
|
||||||
|
{{},
|
||||||
|
nullptr,
|
||||||
|
nullptr,
|
||||||
|
reinterpret_cast<ElementD**>(d_ptrs.data_ptr()),
|
||||||
|
reinterpret_cast<StrideD*>(stride_d.data_ptr())},
|
||||||
|
hw_info,
|
||||||
|
{} // Scheduler
|
||||||
|
};
|
||||||
|
|
||||||
|
Gemm gemm;
|
||||||
|
|
||||||
|
auto can_implement_status = gemm.can_implement(arguments);
|
||||||
|
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess,
|
||||||
|
"Failed to implement GEMM");
|
||||||
|
|
||||||
|
torch::TensorOptions options_uint8 =
|
||||||
|
torch::TensorOptions().dtype(torch::kUInt8).device(d_ptrs.device());
|
||||||
|
size_t workspace_size = gemm.get_workspace_size(arguments);
|
||||||
|
torch::Tensor workspace = torch::empty(workspace_size, options_uint8);
|
||||||
|
|
||||||
|
auto status = gemm.initialize(arguments, workspace.data_ptr(), stream);
|
||||||
|
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM");
|
||||||
|
|
||||||
|
status = gemm.run(stream, nullptr, true); // Enable PDL
|
||||||
|
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename OutType>
|
||||||
|
void cutlass_mxfp8_grouped_mm_dispatch_out_dtype(
|
||||||
|
const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& sfa,
|
||||||
|
const torch::Tensor& sfb, torch::Tensor& d,
|
||||||
|
const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets,
|
||||||
|
const torch::Tensor& blockscale_offsets, cudaStream_t stream) {
|
||||||
|
int num_experts = (int)problem_sizes.size(0);
|
||||||
|
torch::TensorOptions options_int64 =
|
||||||
|
torch::TensorOptions().dtype(torch::kInt64).device(a.device());
|
||||||
|
torch::TensorOptions options_int32 =
|
||||||
|
torch::TensorOptions().dtype(torch::kInt32).device(a.device());
|
||||||
|
|
||||||
|
torch::Tensor a_ptrs = torch::empty(num_experts, options_int64);
|
||||||
|
torch::Tensor b_ptrs = torch::empty(num_experts, options_int64);
|
||||||
|
torch::Tensor sfa_ptrs = torch::empty(num_experts, options_int64);
|
||||||
|
torch::Tensor sfb_ptrs = torch::empty(num_experts, options_int64);
|
||||||
|
torch::Tensor d_ptrs = torch::empty(num_experts, options_int64);
|
||||||
|
|
||||||
|
torch::Tensor stride_a = torch::empty(num_experts, options_int64);
|
||||||
|
torch::Tensor stride_b = torch::empty(num_experts, options_int64);
|
||||||
|
torch::Tensor stride_d = torch::empty(num_experts, options_int64);
|
||||||
|
torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int32);
|
||||||
|
torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int32);
|
||||||
|
|
||||||
|
using GemmTraits = CutlassMxfp8GroupedMmGemmTraits<MMA1SMConfig, OutType>;
|
||||||
|
cutlass_mxfp8_grouped_mm_pre_compute<GemmTraits>(
|
||||||
|
a_ptrs, b_ptrs, sfa_ptrs, sfb_ptrs, d_ptrs, stride_a, stride_b, stride_d,
|
||||||
|
layout_sfa, layout_sfb, a, b, sfa, sfb, d, problem_sizes, expert_offsets,
|
||||||
|
blockscale_offsets, stream);
|
||||||
|
cutlass_mxfp8_grouped_mm<GemmTraits>(
|
||||||
|
a_ptrs, b_ptrs, sfa_ptrs, sfb_ptrs, d_ptrs, stride_a, stride_b, stride_d,
|
||||||
|
layout_sfa, layout_sfb, problem_sizes, stream);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace expert_specialization
|
||||||
127
csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_traits.cuh
Normal file
127
csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_traits.cuh
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
// Adapted from SGLang:
|
||||||
|
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_traits.cuh
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
// Misc
|
||||||
|
#include "cute/tensor.hpp"
|
||||||
|
#include "cutlass/arch/arch.h"
|
||||||
|
#include "cutlass/arch/mma.h"
|
||||||
|
#include "cutlass/cutlass.h"
|
||||||
|
#include "cutlass/detail/sm100_blockscaled_layout.hpp"
|
||||||
|
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||||
|
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||||
|
#include "cutlass/gemm/group_array_problem_shape.hpp"
|
||||||
|
#include "cutlass/layout/layout.h"
|
||||||
|
#include "cutlass/numeric_conversion.h"
|
||||||
|
#include "cutlass/numeric_size.h"
|
||||||
|
|
||||||
|
// Collective Builder
|
||||||
|
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||||
|
#include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp"
|
||||||
|
#include "cutlass/epilogue/thread/activation.h"
|
||||||
|
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||||
|
|
||||||
|
// Integration
|
||||||
|
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||||
|
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||||
|
|
||||||
|
namespace expert_specialization {
|
||||||
|
|
||||||
|
using namespace cute;
|
||||||
|
|
||||||
|
// Different configs for 1SM and 2SM MMA kernel
|
||||||
|
struct MMA1SMConfig {
|
||||||
|
using MmaTileShape = Shape<_128, _128, _128>;
|
||||||
|
using KernelSchedule =
|
||||||
|
cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmMxf8f6f4Sm100;
|
||||||
|
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
|
||||||
|
const static dim3 preferred_cluster;
|
||||||
|
const static dim3 fallback_cluster;
|
||||||
|
};
|
||||||
|
const dim3 MMA1SMConfig::preferred_cluster(1, 4, 1);
|
||||||
|
const dim3 MMA1SMConfig::fallback_cluster(1, 2, 1);
|
||||||
|
|
||||||
|
template <typename _MMAConfig, typename OutputDtype>
|
||||||
|
struct CutlassMxfp8GroupedMmGemmTraits {
|
||||||
|
using MMAConfig = _MMAConfig;
|
||||||
|
using ElementInput = cutlass::float_e4m3_t;
|
||||||
|
using ElementOutput = OutputDtype;
|
||||||
|
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;
|
||||||
|
|
||||||
|
// A matrix configuration
|
||||||
|
using ElementA = cutlass::mx_float8_t<ElementInput>;
|
||||||
|
using LayoutA = cutlass::layout::RowMajor;
|
||||||
|
constexpr static int AlignmentA = 32;
|
||||||
|
|
||||||
|
// B matrix configuration
|
||||||
|
using ElementB = cutlass::mx_float8_t<ElementInput>;
|
||||||
|
using LayoutB = cutlass::layout::ColumnMajor;
|
||||||
|
constexpr static int AlignmentB = 32;
|
||||||
|
|
||||||
|
// C/D matrix configuration
|
||||||
|
using ElementC = void;
|
||||||
|
using ElementD = ElementOutput;
|
||||||
|
using LayoutC = cutlass::layout::RowMajor;
|
||||||
|
using LayoutD = cutlass::layout::RowMajor;
|
||||||
|
constexpr static int AlignmentC = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||||
|
constexpr static int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||||
|
using ElementAccumulator = float;
|
||||||
|
|
||||||
|
static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
|
||||||
|
using CustomEVTIdentity = // acc
|
||||||
|
cutlass::epilogue::fusion::Sm90EVT<
|
||||||
|
cutlass::epilogue::fusion::Sm90Compute<
|
||||||
|
cutlass::epilogue::thread::Identity, ElementD, ElementAccumulator,
|
||||||
|
RoundStyle>,
|
||||||
|
cutlass::epilogue::fusion::Sm90AccFetch>;
|
||||||
|
|
||||||
|
// Core kernel configurations
|
||||||
|
using ArchTag = cutlass::arch::Sm100;
|
||||||
|
using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp;
|
||||||
|
using StageCountType = cutlass::gemm::collective::StageCountAuto;
|
||||||
|
|
||||||
|
// Runtime Cluster Shape
|
||||||
|
using ClusterShape = Shape<int32_t, int32_t, _1>;
|
||||||
|
|
||||||
|
// Define Epilogue
|
||||||
|
using CollectiveEpilogue =
|
||||||
|
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||||
|
ArchTag, OperatorClass, typename MMAConfig::MmaTileShape,
|
||||||
|
ClusterShape, Shape<_64, _64>, ElementAccumulator, ElementAccumulator,
|
||||||
|
ElementC, LayoutC*, AlignmentC, ElementD, LayoutD*, AlignmentD,
|
||||||
|
typename MMAConfig::EpilogueSchedule,
|
||||||
|
CustomEVTIdentity>::CollectiveOp;
|
||||||
|
|
||||||
|
// Define Mainloop
|
||||||
|
using CollectiveMainloop =
|
||||||
|
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||||
|
ArchTag, OperatorClass, ElementA, LayoutA*, AlignmentA, ElementB,
|
||||||
|
LayoutB*, AlignmentB, ElementAccumulator,
|
||||||
|
typename MMAConfig::MmaTileShape, ClusterShape,
|
||||||
|
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||||
|
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||||
|
typename MMAConfig::KernelSchedule>::CollectiveOp;
|
||||||
|
|
||||||
|
// Define GemmKernel
|
||||||
|
using GemmKernel =
|
||||||
|
cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop,
|
||||||
|
CollectiveEpilogue>;
|
||||||
|
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||||
|
|
||||||
|
using ElementSF = typename Gemm::GemmKernel::ElementSF;
|
||||||
|
using StrideA = typename Gemm::GemmKernel::InternalStrideA;
|
||||||
|
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
|
||||||
|
using StrideC = typename Gemm::GemmKernel::InternalStrideC;
|
||||||
|
using StrideD = typename Gemm::GemmKernel::InternalStrideD;
|
||||||
|
using LayoutSFA =
|
||||||
|
typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA;
|
||||||
|
using LayoutSFB =
|
||||||
|
typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB;
|
||||||
|
using Sm1xxBlkScaledConfig =
|
||||||
|
typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace expert_specialization
|
||||||
60
csrc/moe/mxfp8_moe/mxfp8_experts_quant.cu
Normal file
60
csrc/moe/mxfp8_moe/mxfp8_experts_quant.cu
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
// Adapted from SGLang:
|
||||||
|
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_group_quant.cu
|
||||||
|
|
||||||
|
#include <torch/all.h>
|
||||||
|
|
||||||
|
#include "mxfp8_experts_quant.cuh"
|
||||||
|
|
||||||
|
void mxfp8_experts_quant(const torch::Tensor& input,
|
||||||
|
const torch::Tensor& problem_sizes,
|
||||||
|
const torch::Tensor& expert_offsets,
|
||||||
|
const torch::Tensor& blockscale_offsets,
|
||||||
|
torch::Tensor& quant_output,
|
||||||
|
torch::Tensor& scale_factor) {
|
||||||
|
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||||
|
TORCH_CHECK(input.dim() == 2, "input must be 2D tensor");
|
||||||
|
TORCH_CHECK(input.size(1) % 128 == 0, "k must align to 128");
|
||||||
|
TORCH_CHECK(input.strides()[1] == 1, "input must be row major");
|
||||||
|
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
|
||||||
|
TORCH_CHECK(problem_sizes.dtype() == torch::kInt32,
|
||||||
|
"problem_sizes must be int32");
|
||||||
|
TORCH_CHECK(expert_offsets.dtype() == torch::kInt32,
|
||||||
|
"expert_offsets must be int32");
|
||||||
|
TORCH_CHECK(blockscale_offsets.dtype() == torch::kInt32,
|
||||||
|
"blockscale_offsets must be int32");
|
||||||
|
|
||||||
|
auto groups = problem_sizes.size(0);
|
||||||
|
TORCH_CHECK(
|
||||||
|
expert_offsets.dim() == 1 && expert_offsets.size(0) == groups,
|
||||||
|
"expert_offsets must be 1D and have size equal to the number of groups");
|
||||||
|
TORCH_CHECK(
|
||||||
|
blockscale_offsets.dim() == 1 && blockscale_offsets.size(0) == groups,
|
||||||
|
"blockscale_offsets must be 1D and have size equal to the number of "
|
||||||
|
"groups");
|
||||||
|
|
||||||
|
auto stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
if (input.dtype() == torch::kBFloat16) {
|
||||||
|
expert_specialization::launch_mxfp8_experts_quant<__nv_bfloat16>(
|
||||||
|
input, problem_sizes, expert_offsets, blockscale_offsets, quant_output,
|
||||||
|
scale_factor);
|
||||||
|
} else if (input.dtype() == torch::kFloat16) {
|
||||||
|
expert_specialization::launch_mxfp8_experts_quant<__half>(
|
||||||
|
input, problem_sizes, expert_offsets, blockscale_offsets, quant_output,
|
||||||
|
scale_factor);
|
||||||
|
} else {
|
||||||
|
TORCH_CHECK(false, "dtype must be kFloat16 or kBFloat16");
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
TORCH_CHECK(false,
|
||||||
|
"No implemented mxfp8_experts_quant for "
|
||||||
|
"current device");
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
#include "core/registration.h"
|
||||||
|
|
||||||
|
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||||
|
m.impl("mxfp8_experts_quant", mxfp8_experts_quant);
|
||||||
|
}
|
||||||
414
csrc/moe/mxfp8_moe/mxfp8_experts_quant.cuh
Normal file
414
csrc/moe/mxfp8_moe/mxfp8_experts_quant.cuh
Normal file
@@ -0,0 +1,414 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
// Adapted from SGLang:
|
||||||
|
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_group_quant.cuh
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
#include <cuda.h>
|
||||||
|
#include <cuda_bf16.h>
|
||||||
|
#include <cuda_fp16.h>
|
||||||
|
#include <torch/all.h>
|
||||||
|
|
||||||
|
#include <cuda/ptx>
|
||||||
|
|
||||||
|
#include "cute/tensor.hpp"
|
||||||
|
|
||||||
|
namespace expert_specialization {
|
||||||
|
|
||||||
|
using namespace cute;
|
||||||
|
|
||||||
|
constexpr uint32_t THREAD_BLOCK_SIZE = 128;
|
||||||
|
constexpr uint32_t WARP_SIZE = 32;
|
||||||
|
constexpr int BLOCK_M = 128;
|
||||||
|
constexpr int BLOCK_K = 128;
|
||||||
|
using ThrLayout = Layout<Shape<_16, _8>, Stride<_8, _1>>;
|
||||||
|
using ValLayout = Layout<Shape<_1, _16>>;
|
||||||
|
using SfR2SThrLayout = Layout<Shape<_16, _4>, Stride<_4, _1>>;
|
||||||
|
using SfR2SValLayout = Layout<Shape<_1, _1>>;
|
||||||
|
using ScaleFactorTileLayout =
|
||||||
|
Layout<Shape<Shape<_32, _4>, _4>, Stride<Stride<_16, _4>, _1>>;
|
||||||
|
|
||||||
|
// Fast reciprocal.
|
||||||
|
inline __device__ float reciprocal_approximate_ftz(float a) {
|
||||||
|
float b;
|
||||||
|
asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a));
|
||||||
|
return b;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Some code references TRT-LLM:
|
||||||
|
// https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/quantization.cuh
|
||||||
|
template <typename FragmentS, typename FragmentD>
|
||||||
|
__inline__ __device__ uint8_t cvt_warp_fp16_to_mxfp8(FragmentS& fragment_s,
|
||||||
|
FragmentD& fragment_d) {
|
||||||
|
using FragmentSLayout = typename FragmentS::layout_type;
|
||||||
|
using FragmentDLayout = typename FragmentD::layout_type;
|
||||||
|
FragmentSLayout fragment_s_layout;
|
||||||
|
FragmentDLayout fragment_d_layout;
|
||||||
|
static_assert(is_static<FragmentSLayout>::value &&
|
||||||
|
size(fragment_s_layout) == 16);
|
||||||
|
static_assert(is_static<FragmentDLayout>::value &&
|
||||||
|
size(fragment_d_layout) == 16);
|
||||||
|
|
||||||
|
constexpr int eles_per_thr = 16;
|
||||||
|
using ValType = typename FragmentS::element_type;
|
||||||
|
using VecType = std::conditional_t<std::is_same_v<ValType, __nv_bfloat16>,
|
||||||
|
__nv_bfloat162, __half2>;
|
||||||
|
VecType vec[8];
|
||||||
|
// Assign vals
|
||||||
|
vec[0].x = fragment_s(Int<0>{});
|
||||||
|
vec[0].y = fragment_s(Int<1>{});
|
||||||
|
vec[1].x = fragment_s(Int<2>{});
|
||||||
|
vec[1].y = fragment_s(Int<3>{});
|
||||||
|
vec[2].x = fragment_s(Int<4>{});
|
||||||
|
vec[2].y = fragment_s(Int<5>{});
|
||||||
|
vec[3].x = fragment_s(Int<6>{});
|
||||||
|
vec[3].y = fragment_s(Int<7>{});
|
||||||
|
vec[4].x = fragment_s(Int<8>{});
|
||||||
|
vec[4].y = fragment_s(Int<9>{});
|
||||||
|
vec[5].x = fragment_s(Int<10>{});
|
||||||
|
vec[5].y = fragment_s(Int<11>{});
|
||||||
|
vec[6].x = fragment_s(Int<12>{});
|
||||||
|
vec[6].y = fragment_s(Int<13>{});
|
||||||
|
vec[7].x = fragment_s(Int<14>{});
|
||||||
|
vec[7].y = fragment_s(Int<15>{});
|
||||||
|
|
||||||
|
auto local_max = __habs2(vec[0]);
|
||||||
|
for (int i = 1; i < eles_per_thr / 2; i++) {
|
||||||
|
local_max = __hmax2(__habs2(vec[i]), local_max);
|
||||||
|
}
|
||||||
|
local_max = __hmax2(__shfl_xor_sync(uint32_t(-1), local_max, 1), local_max);
|
||||||
|
|
||||||
|
// Get the final absolute maximum values.
|
||||||
|
float block_max(0.0f);
|
||||||
|
if constexpr (std::is_same_v<ValType, __nv_bfloat16>) {
|
||||||
|
block_max = __bfloat162float(__hmax(local_max.x, local_max.y));
|
||||||
|
} else {
|
||||||
|
block_max = __half2float(__hmax(local_max.x, local_max.y));
|
||||||
|
}
|
||||||
|
// Get the SF (max value of the vector / max value of mxfp8).
|
||||||
|
float sf_val = block_max * reciprocal_approximate_ftz(448.0f);
|
||||||
|
// 8 bits representation of the SF.
|
||||||
|
uint8_t fp8_sf_val;
|
||||||
|
|
||||||
|
__nv_fp8_e8m0 tmp_sf_val;
|
||||||
|
tmp_sf_val.__x =
|
||||||
|
__nv_cvt_float_to_e8m0(sf_val, __NV_SATFINITE, cudaRoundPosInf);
|
||||||
|
sf_val = static_cast<float>(tmp_sf_val);
|
||||||
|
fp8_sf_val = tmp_sf_val.__x;
|
||||||
|
// Get the output scale (reciprocal of the SFValue).
|
||||||
|
float output_scale =
|
||||||
|
block_max != 0.f ? reciprocal_approximate_ftz(sf_val) : 0.0f;
|
||||||
|
|
||||||
|
// Convert the input to float.
|
||||||
|
float2 fp2_vals[eles_per_thr / 2];
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < eles_per_thr / 2; i++) {
|
||||||
|
if constexpr (std::is_same_v<ValType, __half>) {
|
||||||
|
fp2_vals[i] = __half22float2(vec[i]);
|
||||||
|
} else {
|
||||||
|
fp2_vals[i] = __bfloat1622float2(vec[i]);
|
||||||
|
}
|
||||||
|
fp2_vals[i].x *= output_scale;
|
||||||
|
fp2_vals[i].y *= output_scale;
|
||||||
|
}
|
||||||
|
union {
|
||||||
|
uint8_t bytes[16];
|
||||||
|
__nv_fp8x2_e4m3 elts[8];
|
||||||
|
} u;
|
||||||
|
u.elts[0] = __nv_fp8x2_e4m3(fp2_vals[0]);
|
||||||
|
u.elts[1] = __nv_fp8x2_e4m3(fp2_vals[1]);
|
||||||
|
u.elts[2] = __nv_fp8x2_e4m3(fp2_vals[2]);
|
||||||
|
u.elts[3] = __nv_fp8x2_e4m3(fp2_vals[3]);
|
||||||
|
u.elts[4] = __nv_fp8x2_e4m3(fp2_vals[4]);
|
||||||
|
u.elts[5] = __nv_fp8x2_e4m3(fp2_vals[5]);
|
||||||
|
u.elts[6] = __nv_fp8x2_e4m3(fp2_vals[6]);
|
||||||
|
u.elts[7] = __nv_fp8x2_e4m3(fp2_vals[7]);
|
||||||
|
fragment_d(Int<0>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[0]);
|
||||||
|
fragment_d(Int<1>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[1]);
|
||||||
|
fragment_d(Int<2>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[2]);
|
||||||
|
fragment_d(Int<3>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[3]);
|
||||||
|
fragment_d(Int<4>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[4]);
|
||||||
|
fragment_d(Int<5>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[5]);
|
||||||
|
fragment_d(Int<6>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[6]);
|
||||||
|
fragment_d(Int<7>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[7]);
|
||||||
|
fragment_d(Int<8>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[8]);
|
||||||
|
fragment_d(Int<9>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[9]);
|
||||||
|
fragment_d(Int<10>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[10]);
|
||||||
|
fragment_d(Int<11>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[11]);
|
||||||
|
fragment_d(Int<12>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[12]);
|
||||||
|
fragment_d(Int<13>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[13]);
|
||||||
|
fragment_d(Int<14>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[14]);
|
||||||
|
fragment_d(Int<15>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[15]);
|
||||||
|
return fp8_sf_val;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename TensorS, typename TensorP, typename TensorD,
|
||||||
|
typename TensorSharedSF, typename TensorSF, typename TiledCopyG2R,
|
||||||
|
typename TiledCopyR2G, typename TiledCopyR2S>
|
||||||
|
__inline__ __device__ void mxfp8_experts_quant_tile(
|
||||||
|
TensorS& tensor_s, TensorP& tensor_p, TensorD& tensor_d,
|
||||||
|
TensorSharedSF& tensor_shared_sf, TensorSF& tensor_sf, int m,
|
||||||
|
TiledCopyG2R& tiled_copy_g2r, TiledCopyR2G& tiled_copy_r2g,
|
||||||
|
TiledCopyR2S& tiled_copy_r2s) {
|
||||||
|
static_assert(size(get<0>(typename TensorS::layout_type{})) == 128 &&
|
||||||
|
size(get<1>(typename TensorS::layout_type{})) == 128 &&
|
||||||
|
stride(get<1>(typename TensorS::layout_type{})) == 1);
|
||||||
|
static_assert(size(get<0>(typename TensorD::layout_type{})) == 128 &&
|
||||||
|
size(get<1>(typename TensorD::layout_type{})) == 128 &&
|
||||||
|
stride(get<1>(typename TensorD::layout_type{})) == 1);
|
||||||
|
static_assert(size(get<0>(typename TensorP::layout_type{})) == 128 &&
|
||||||
|
size(get<1>(typename TensorP::layout_type{})) == 128);
|
||||||
|
static_assert(size(get<0>(typename TensorSharedSF::layout_type{})) == 128 &&
|
||||||
|
size(get<1>(typename TensorSharedSF::layout_type{})) == 4);
|
||||||
|
static_assert(size(get<0>(typename TensorSF::layout_type{})) == 128 &&
|
||||||
|
size(get<1>(typename TensorSF::layout_type{})) == 4);
|
||||||
|
|
||||||
|
using Tiler_MN = typename TiledCopyG2R::Tiler_MN;
|
||||||
|
auto tiler_mn = Tiler_MN{};
|
||||||
|
static_assert(size<0>(tiler_mn) == 16 && size<1>(tiler_mn) == 128);
|
||||||
|
|
||||||
|
auto tiled_tensor_s = tiled_divide(tensor_s, tiler_mn);
|
||||||
|
auto tiled_tensor_p = tiled_divide(tensor_p, tiler_mn);
|
||||||
|
auto tiled_tensor_d = tiled_divide(tensor_d, tiler_mn);
|
||||||
|
static_assert(size<2>(tiled_tensor_s) == 1);
|
||||||
|
static_assert(size<2>(tiled_tensor_p) == 1);
|
||||||
|
static_assert(size<2>(tiled_tensor_d) == 1);
|
||||||
|
auto squeeze_tiled_tensor_s = take<0, 2>(tiled_tensor_s);
|
||||||
|
auto squeeze_tiled_tensor_p = take<0, 2>(tiled_tensor_p);
|
||||||
|
auto squeeze_tiled_tensor_d = take<0, 2>(tiled_tensor_d);
|
||||||
|
|
||||||
|
using SF_Tiler_MN = typename TiledCopyR2S::Tiler_MN;
|
||||||
|
auto sf_tiler_mn = SF_Tiler_MN{};
|
||||||
|
static_assert(size<0>(sf_tiler_mn) == 16 && size<1>(sf_tiler_mn) == 4);
|
||||||
|
|
||||||
|
auto tiled_tensor_sf = tiled_divide(tensor_sf, sf_tiler_mn);
|
||||||
|
auto tiled_tensor_shared_sf = tiled_divide(tensor_shared_sf, sf_tiler_mn);
|
||||||
|
auto squeeze_tiled_tensor_sf = take<0, 2>(tiled_tensor_sf);
|
||||||
|
auto squeeze_tiled_tensor_shared_sf = take<0, 2>(tiled_tensor_shared_sf);
|
||||||
|
|
||||||
|
constexpr int tile_loop_count = size<1>(tiled_tensor_s);
|
||||||
|
constexpr int rows_in_tile = 16;
|
||||||
|
// We don't need to clear shared memory
|
||||||
|
// clear(squeeze_tiled_tensor_shared_sf);
|
||||||
|
#pragma unroll 4
|
||||||
|
for (int t = 0; t < tile_loop_count; t++) {
|
||||||
|
if (t * rows_in_tile >= m) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
auto current_copy_tile_s = tensor<0>(squeeze_tiled_tensor_s(_, t));
|
||||||
|
auto current_copy_tile_p = tensor<0>(squeeze_tiled_tensor_p(_, t));
|
||||||
|
auto current_copy_tile_d = tensor<0>(squeeze_tiled_tensor_d(_, t));
|
||||||
|
auto current_copy_tile_sf = tensor<0>(squeeze_tiled_tensor_sf(_, t));
|
||||||
|
auto current_copy_tile_shared_sf =
|
||||||
|
tensor<0>(squeeze_tiled_tensor_shared_sf(_, t));
|
||||||
|
|
||||||
|
// Global to Register copy
|
||||||
|
auto thr_copy_g2r = tiled_copy_g2r.get_thread_slice(threadIdx.x);
|
||||||
|
auto thr_tile_g2r_s = thr_copy_g2r.partition_S(current_copy_tile_s);
|
||||||
|
auto thr_tile_g2r_p = thr_copy_g2r.partition_S(current_copy_tile_p);
|
||||||
|
auto input_fragment = make_fragment_like(thr_tile_g2r_s);
|
||||||
|
|
||||||
|
// Register to Global copy
|
||||||
|
auto thr_copy_r2g = tiled_copy_r2g.get_thread_slice(threadIdx.x);
|
||||||
|
auto thr_tile_r2g_d = thr_copy_r2g.partition_D(current_copy_tile_d);
|
||||||
|
auto thr_tile_r2g_p = thr_copy_r2g.partition_D(current_copy_tile_p);
|
||||||
|
auto output_fragment = make_fragment_like(thr_tile_r2g_d);
|
||||||
|
|
||||||
|
// Register to Shared copy
|
||||||
|
auto thr_copy_r2s = tiled_copy_r2s.get_thread_slice(threadIdx.x / 2);
|
||||||
|
auto thr_tile_r2s_shared_sf =
|
||||||
|
thr_copy_r2s.partition_D(current_copy_tile_shared_sf);
|
||||||
|
auto shared_sf_fragment = make_fragment_like(thr_tile_r2s_shared_sf);
|
||||||
|
|
||||||
|
// CopyG2R & convert & CopyR2G
|
||||||
|
copy_if(tiled_copy_g2r, thr_tile_g2r_p, thr_tile_g2r_s, input_fragment);
|
||||||
|
uint8_t fp8_sf_val =
|
||||||
|
cvt_warp_fp16_to_mxfp8(input_fragment, output_fragment);
|
||||||
|
copy_if(tiled_copy_r2g, thr_tile_r2g_p, output_fragment, thr_tile_r2g_d);
|
||||||
|
shared_sf_fragment[0] = fp8_sf_val;
|
||||||
|
|
||||||
|
// Before first copy r2s, clear shared memory and wait previous group
|
||||||
|
if (t == 0 && threadIdx.x == 0) {
|
||||||
|
// Wait for the group to have completed reading from shared memory.
|
||||||
|
cuda::ptx::cp_async_bulk_wait_group_read(cuda::ptx::n32_t<0>());
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
if (threadIdx.x % 2 == 0) {
|
||||||
|
copy(tiled_copy_r2s, shared_sf_fragment, thr_tile_r2s_shared_sf);
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for shared memory writes to be visible to TMA engine.
|
||||||
|
cuda::ptx::fence_proxy_async(cuda::ptx::space_shared); // b)
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
cuda::ptx::cp_async_bulk(cuda::ptx::space_global, cuda::ptx::space_shared,
|
||||||
|
squeeze_tiled_tensor_sf.data().get(),
|
||||||
|
squeeze_tiled_tensor_shared_sf.data().get(), 512);
|
||||||
|
// Wait for TMA transfer to have finished reading shared memory.
|
||||||
|
// Create a "bulk async-group" out of the previous bulk copy operation.
|
||||||
|
cuda::ptx::cp_async_bulk_commit_group();
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T_IN, typename TiledCopyG2R, typename TiledCopyR2G,
|
||||||
|
typename TiledCopyR2S>
|
||||||
|
__global__ void mxfp8_experts_quant_kernel(
|
||||||
|
const T_IN* input, const int* problem_sizes, const int* expert_offsets,
|
||||||
|
const int* blockscale_offsets, cutlass::float_e4m3_t* quant_output,
|
||||||
|
uint8_t* scale_factor, int groups, TiledCopyG2R tiled_copy_g2r,
|
||||||
|
TiledCopyR2G tiled_copy_r2g, TiledCopyR2S tiled_copy_r2s) {
|
||||||
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000
|
||||||
|
__shared__ __align__(512) uint8_t shared_memory[512];
|
||||||
|
ScaleFactorTileLayout scale_factor_tile_layout{};
|
||||||
|
auto scale_factor_shared =
|
||||||
|
make_tensor(make_smem_ptr(shared_memory),
|
||||||
|
scale_factor_tile_layout); // ((_32,_4), _4):((_16,_4), _1)
|
||||||
|
// TODO: Transform Groupwise Schedule into a more efficient Schedule
|
||||||
|
for (int g = 0; g < groups; g++) {
|
||||||
|
int m = problem_sizes[g * 3 + 0];
|
||||||
|
int k = problem_sizes[g * 3 + 2];
|
||||||
|
int64_t expert_offset = static_cast<int64_t>(expert_offsets[g]);
|
||||||
|
int64_t blockscale_offset = static_cast<int64_t>(blockscale_offsets[g]);
|
||||||
|
|
||||||
|
auto input_tensor = make_tensor(
|
||||||
|
make_gmem_ptr(input + expert_offset * k),
|
||||||
|
make_layout(make_shape(m, k),
|
||||||
|
LayoutRight{})); // (M, K):(K, 1) half_t/bfloat16_t
|
||||||
|
|
||||||
|
auto quant_output_tensor = make_tensor(
|
||||||
|
make_gmem_ptr(quant_output + expert_offset * k),
|
||||||
|
make_layout(make_shape(m, k),
|
||||||
|
LayoutRight{})); // (M, K):(K, 1) cutlass::float_e4m3_t
|
||||||
|
|
||||||
|
auto scale_factor_shape = make_shape(ceil_div(m, 128) * 128, k / 32);
|
||||||
|
auto scale_factor_layout = tile_to_shape(scale_factor_tile_layout,
|
||||||
|
scale_factor_shape, LayoutRight{});
|
||||||
|
// layout<0>(layout<0>(scale_factor_layout)) (_32,_4):(_16,_4) -- static
|
||||||
|
// layout<1>(layout<0>(scale_factor_layout)) M_align_128 / 128 -- dynamic
|
||||||
|
// shape dynamic stride layout<0>(layout<1>(scale_factor_layout)) _4:_1 --
|
||||||
|
// static layout<1>(layout<1>(scale_factor_layout)) (K / 32) / 4 : _512 --
|
||||||
|
// dynamic shape static stride
|
||||||
|
|
||||||
|
// Reshape to zipped layout for 1D indexing
|
||||||
|
auto zipped_scale_factor_layout = make_layout(
|
||||||
|
make_layout(layout<0>(layout<0>(scale_factor_layout)),
|
||||||
|
layout<0>(layout<1>(scale_factor_layout))),
|
||||||
|
make_layout(
|
||||||
|
layout<1>(layout<0>(scale_factor_layout)),
|
||||||
|
layout<1>(layout<1>(
|
||||||
|
scale_factor_layout)))); // (((_32,_4),_4),(M_align_128 /
|
||||||
|
// 128,(K / 32) /
|
||||||
|
// 4)):(((_16,_4),_1),(?,_512))
|
||||||
|
|
||||||
|
auto scale_factor_tensor =
|
||||||
|
make_tensor(make_gmem_ptr(scale_factor + blockscale_offset * (k / 32)),
|
||||||
|
zipped_scale_factor_layout);
|
||||||
|
|
||||||
|
// Used for cases where M is not divisible by 128 (most scenarios).
|
||||||
|
auto input_shape = shape(input_tensor); // (M, K):(K, 1)
|
||||||
|
auto identity_tensor = make_identity_tensor(input_shape);
|
||||||
|
auto predict_tensor = cute::lazy::transform(
|
||||||
|
identity_tensor, [&](auto c) { return elem_less(c, input_shape); });
|
||||||
|
|
||||||
|
// (_128, _128)
|
||||||
|
auto tiler = make_shape(Int<BLOCK_M>{}, Int<BLOCK_K>{});
|
||||||
|
|
||||||
|
auto tiled_input_tensor = zipped_divide(
|
||||||
|
input_tensor, tiler); // ((128, 128), (cdiv(M, 128), cdiv(K, 128)))
|
||||||
|
auto tiled_quant_output_tensor =
|
||||||
|
zipped_divide(quant_output_tensor,
|
||||||
|
tiler); // ((128, 128), (cdiv(M, 128), cdiv(K, 128)))
|
||||||
|
auto tiled_predict_tensor = zipped_divide(
|
||||||
|
predict_tensor, tiler); // ((128, 128), (cdiv(M, 128), cdiv(K, 128)))
|
||||||
|
|
||||||
|
auto total_tiles =
|
||||||
|
size<1>(tiled_input_tensor); // cdiv(M, 128) * cdiv(K, 128)
|
||||||
|
decltype(total_tiles) blk_offset = blockIdx.x;
|
||||||
|
while (blk_offset < total_tiles) {
|
||||||
|
auto current_input_tile = tensor<0>(tiled_input_tensor(_, blk_offset));
|
||||||
|
auto current_quant_output_tile =
|
||||||
|
tensor<0>(tiled_quant_output_tensor(_, blk_offset));
|
||||||
|
auto current_predict_tile =
|
||||||
|
tensor<0>(tiled_predict_tensor(_, blk_offset));
|
||||||
|
auto current_scale_factor_tile =
|
||||||
|
tensor<0>(scale_factor_tensor(_, blk_offset));
|
||||||
|
|
||||||
|
mxfp8_experts_quant_tile<
|
||||||
|
decltype(current_input_tile), decltype(current_predict_tile),
|
||||||
|
decltype(current_quant_output_tile), decltype(scale_factor_shared),
|
||||||
|
decltype(current_scale_factor_tile), TiledCopyG2R, TiledCopyR2G,
|
||||||
|
TiledCopyR2S>(current_input_tile, current_predict_tile,
|
||||||
|
current_quant_output_tile, scale_factor_shared,
|
||||||
|
current_scale_factor_tile, m, tiled_copy_g2r,
|
||||||
|
tiled_copy_r2g, tiled_copy_r2s);
|
||||||
|
blk_offset += gridDim.x;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T_IN>
|
||||||
|
void launch_mxfp8_experts_quant(const torch::Tensor& input,
|
||||||
|
const torch::Tensor& problem_sizes,
|
||||||
|
const torch::Tensor& expert_offsets,
|
||||||
|
const torch::Tensor& blockscale_offsets,
|
||||||
|
torch::Tensor& quant_output,
|
||||||
|
torch::Tensor& scale_factor) {
|
||||||
|
ThrLayout thr_layout{};
|
||||||
|
ValLayout val_layout{};
|
||||||
|
SfR2SThrLayout r2s_thr_layout{};
|
||||||
|
SfR2SValLayout r2s_val_layout{};
|
||||||
|
|
||||||
|
using CopyOpG2R =
|
||||||
|
UniversalCopy<cutlass::AlignedArray<T_IN, size(val_layout)>>;
|
||||||
|
using CopyAtomG2R = cute::Copy_Atom<CopyOpG2R, T_IN>;
|
||||||
|
auto tiled_copy_g2r = cute::make_tiled_copy(
|
||||||
|
CopyAtomG2R{}, thr_layout, val_layout); // Tiler_MN: (16, 128)
|
||||||
|
|
||||||
|
using CopyOpR2G = UniversalCopy<
|
||||||
|
cutlass::AlignedArray<cutlass::float_e4m3_t, size(val_layout)>>;
|
||||||
|
using CopyAtomR2G = cute::Copy_Atom<CopyOpR2G, cutlass::float_e4m3_t>;
|
||||||
|
auto tiled_copy_r2g = cute::make_tiled_copy(
|
||||||
|
CopyAtomR2G{}, thr_layout, val_layout); // Tiler_MN: (16, 128)
|
||||||
|
|
||||||
|
using CopyOpR2S =
|
||||||
|
UniversalCopy<cutlass::AlignedArray<uint8_t, size(r2s_val_layout)>>;
|
||||||
|
using CopyAtomR2S = cute::Copy_Atom<CopyOpR2S, uint8_t>;
|
||||||
|
auto tiled_copy_r2s = cute::make_tiled_copy(
|
||||||
|
CopyAtomR2S{}, r2s_thr_layout, r2s_val_layout); // Tiler_MN: (16, 4)
|
||||||
|
|
||||||
|
int max_active_blocks_per_sm = -1;
|
||||||
|
AT_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||||
|
&max_active_blocks_per_sm,
|
||||||
|
mxfp8_experts_quant_kernel<T_IN, decltype(tiled_copy_g2r),
|
||||||
|
decltype(tiled_copy_r2g),
|
||||||
|
decltype(tiled_copy_r2s)>,
|
||||||
|
THREAD_BLOCK_SIZE, 0));
|
||||||
|
|
||||||
|
dim3 grid(at::cuda::getCurrentDeviceProperties()->multiProcessorCount *
|
||||||
|
max_active_blocks_per_sm,
|
||||||
|
1, 1);
|
||||||
|
dim3 block(THREAD_BLOCK_SIZE, 1, 1);
|
||||||
|
int num_experts = (int)problem_sizes.size(0);
|
||||||
|
auto stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
mxfp8_experts_quant_kernel<T_IN, decltype(tiled_copy_g2r),
|
||||||
|
decltype(tiled_copy_r2g), decltype(tiled_copy_r2s)>
|
||||||
|
<<<grid, block, 0, stream>>>(
|
||||||
|
reinterpret_cast<const T_IN*>(input.data_ptr()),
|
||||||
|
reinterpret_cast<const int*>(problem_sizes.data_ptr()),
|
||||||
|
reinterpret_cast<const int*>(expert_offsets.data_ptr()),
|
||||||
|
reinterpret_cast<const int*>(blockscale_offsets.data_ptr()),
|
||||||
|
reinterpret_cast<cutlass::float_e4m3_t*>(quant_output.data_ptr()),
|
||||||
|
reinterpret_cast<uint8_t*>(scale_factor.data_ptr()), num_experts,
|
||||||
|
tiled_copy_g2r, tiled_copy_r2g, tiled_copy_r2s);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace expert_specialization
|
||||||
52
csrc/moe/router_gemm.cu
Normal file
52
csrc/moe/router_gemm.cu
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
// bf16 x bf16 -> fp32 router GEMM via cuBLAS.
|
||||||
|
// Uses CUBLAS_COMPUTE_32F so bf16 operands accumulate into fp32,
|
||||||
|
// matching TRT-LLM's cuBLAS fallback behaviour in dsv3RouterGemmOp.
|
||||||
|
|
||||||
|
#include <torch/all.h>
|
||||||
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
#include <cublas_v2.h>
|
||||||
|
|
||||||
|
// cuBLAS column-major math for row-major PyTorch tensors:
|
||||||
|
// weight[N,K]_row lda=K -> cuBLAS sees (K,N) col-major; CUBLAS_OP_T ->
|
||||||
|
// (N,K) input[M,K]_row ldb=K -> cuBLAS sees (K,M) col-major; CUBLAS_OP_N
|
||||||
|
// -> (K,M) out[M,N]_row ldc=N -> cuBLAS sees (N,M) col-major (written as
|
||||||
|
// output^T)
|
||||||
|
// cuBLAS: C(N,M) = weight(N,K) @ input(K,M) => C^T = output[M,N]
|
||||||
|
// params: m=N, n=M, k=K, lda=K (weight), ldb=K (input), ldc=N (output)
|
||||||
|
|
||||||
|
torch::Tensor router_gemm_bf16_fp32(torch::Tensor const& input,
|
||||||
|
torch::Tensor const& weight) {
|
||||||
|
TORCH_CHECK(input.dtype() == torch::kBFloat16,
|
||||||
|
"router_gemm_bf16_fp32: input must be bfloat16");
|
||||||
|
TORCH_CHECK(weight.dtype() == torch::kBFloat16,
|
||||||
|
"router_gemm_bf16_fp32: weight must be bfloat16");
|
||||||
|
TORCH_CHECK(input.dim() == 2 && weight.dim() == 2,
|
||||||
|
"router_gemm_bf16_fp32: input and weight must be 2-D");
|
||||||
|
TORCH_CHECK(input.size(1) == weight.size(1),
|
||||||
|
"router_gemm_bf16_fp32: inner dimensions must match");
|
||||||
|
|
||||||
|
int64_t const M = input.size(0);
|
||||||
|
int64_t const N = weight.size(0);
|
||||||
|
int64_t const K = input.size(1);
|
||||||
|
|
||||||
|
auto out = torch::empty({M, N}, input.options().dtype(torch::kFloat32));
|
||||||
|
|
||||||
|
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
|
||||||
|
TORCH_CUDABLAS_CHECK(
|
||||||
|
cublasSetStream(handle, at::cuda::getCurrentCUDAStream()));
|
||||||
|
|
||||||
|
float const alpha = 1.0f;
|
||||||
|
float const beta = 0.0f;
|
||||||
|
|
||||||
|
TORCH_CUDABLAS_CHECK(cublasGemmEx(
|
||||||
|
handle, CUBLAS_OP_T, CUBLAS_OP_N, static_cast<int>(N),
|
||||||
|
static_cast<int>(M), static_cast<int>(K), &alpha, weight.data_ptr(),
|
||||||
|
CUDA_R_16BF, static_cast<int>(K), input.data_ptr(), CUDA_R_16BF,
|
||||||
|
static_cast<int>(K), &beta, out.data_ptr(), CUDA_R_32F,
|
||||||
|
static_cast<int>(N), CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT));
|
||||||
|
|
||||||
|
return out;
|
||||||
|
}
|
||||||
@@ -125,6 +125,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
|
|||||||
"Tensor)");
|
"Tensor)");
|
||||||
m.impl("grouped_topk", torch::kCUDA, &grouped_topk);
|
m.impl("grouped_topk", torch::kCUDA, &grouped_topk);
|
||||||
|
|
||||||
|
// cuBLAS bf16 x bf16 -> fp32 router GEMM (fallback for non-SM90 / batch > 16)
|
||||||
|
m.def("router_gemm_bf16_fp32(Tensor input, Tensor weight) -> Tensor");
|
||||||
|
m.impl("router_gemm_bf16_fp32", torch::kCUDA, &router_gemm_bf16_fp32);
|
||||||
|
|
||||||
// DeepSeek V3 optimized router GEMM for SM90+
|
// DeepSeek V3 optimized router GEMM for SM90+
|
||||||
m.def("dsv3_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
|
m.def("dsv3_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
|
||||||
// conditionally compiled so impl registration is in source file
|
// conditionally compiled so impl registration is in source file
|
||||||
|
|||||||
18
csrc/ops.h
18
csrc/ops.h
@@ -269,13 +269,13 @@ void get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
|
|||||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||||
const int64_t n, const int64_t k, const bool swap_ab);
|
const int64_t n, const int64_t k, const bool swap_ab);
|
||||||
|
|
||||||
void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
|
void get_cutlass_batched_moe_mm_data(torch::Tensor& expert_offsets,
|
||||||
torch::Tensor& problem_sizes1,
|
torch::Tensor& problem_sizes1,
|
||||||
torch::Tensor& problem_sizes2,
|
torch::Tensor& problem_sizes2,
|
||||||
const torch::Tensor& expert_num_tokens,
|
const torch::Tensor& expert_num_tokens,
|
||||||
const int64_t num_local_experts,
|
const int64_t num_local_experts,
|
||||||
const int64_t padded_m, const int64_t n,
|
const int64_t padded_m, const int64_t n,
|
||||||
const int64_t k);
|
const int64_t k);
|
||||||
|
|
||||||
void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
|
void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
|
||||||
torch::Tensor const& b,
|
torch::Tensor const& b,
|
||||||
@@ -371,7 +371,9 @@ void selective_scan_fwd(
|
|||||||
const torch::Tensor& ssm_states, int64_t pad_slot_id, int64_t block_size,
|
const torch::Tensor& ssm_states, int64_t pad_slot_id, int64_t block_size,
|
||||||
const std::optional<torch::Tensor>& block_idx_first_scheduled_token,
|
const std::optional<torch::Tensor>& block_idx_first_scheduled_token,
|
||||||
const std::optional<torch::Tensor>& block_idx_last_scheduled_token,
|
const std::optional<torch::Tensor>& block_idx_last_scheduled_token,
|
||||||
const std::optional<torch::Tensor>& initial_state_idx);
|
const std::optional<torch::Tensor>& initial_state_idx,
|
||||||
|
const std::optional<torch::Tensor>& cu_chunk_seqlen,
|
||||||
|
const std::optional<torch::Tensor>& last_chunk_indices);
|
||||||
|
|
||||||
torch::Tensor dynamic_4bit_int_moe_cpu(
|
torch::Tensor dynamic_4bit_int_moe_cpu(
|
||||||
torch::Tensor x, torch::Tensor topk_ids, torch::Tensor topk_weights,
|
torch::Tensor x, torch::Tensor topk_ids, torch::Tensor topk_weights,
|
||||||
|
|||||||
@@ -39,12 +39,12 @@ namespace vllm {
|
|||||||
template <class Type, bool UE8M0_SF = false>
|
template <class Type, bool UE8M0_SF = false>
|
||||||
__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
|
__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
|
||||||
silu_mul_cvt_fp16_to_fp4(int32_t numRows, int32_t numCols,
|
silu_mul_cvt_fp16_to_fp4(int32_t numRows, int32_t numCols,
|
||||||
int32_t num_padded_cols,
|
int32_t num_packed_cols,
|
||||||
Type const* __restrict__ in,
|
Type const* __restrict__ in,
|
||||||
float const* __restrict__ SFScale,
|
float const* __restrict__ SFScale,
|
||||||
uint32_t* __restrict__ out,
|
uint32_t* __restrict__ out,
|
||||||
uint32_t* __restrict__ SFout) {
|
uint32_t* __restrict__ SFout) {
|
||||||
using PackedVec = vllm::PackedVec<Type>;
|
using PackedVec = vllm::PackedVec<Type, CVT_FP4_PACK16>;
|
||||||
static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
|
static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
|
||||||
(CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
|
(CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
|
||||||
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
|
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
|
||||||
@@ -63,7 +63,7 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
|
|||||||
|
|
||||||
// Input tensor row/col loops.
|
// Input tensor row/col loops.
|
||||||
for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
|
for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
|
||||||
if (colIdx < num_padded_cols) {
|
if (colIdx < num_packed_cols) {
|
||||||
PackedVec in_vec;
|
PackedVec in_vec;
|
||||||
PackedVec in_vec2;
|
PackedVec in_vec2;
|
||||||
int64_t inOffset =
|
int64_t inOffset =
|
||||||
@@ -73,19 +73,19 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
|
|||||||
|
|
||||||
bool valid = (rowIdx < numRows) && (elem_idx < numCols);
|
bool valid = (rowIdx < numRows) && (elem_idx < numCols);
|
||||||
if constexpr (CVT_FP4_PACK16) {
|
if constexpr (CVT_FP4_PACK16) {
|
||||||
ld256_or_zero_cg_u32<Type>(
|
ld256_cg_or_zero(reinterpret_cast<u32x8_t&>(in_vec),
|
||||||
in_vec, &reinterpret_cast<const uint32_t*>(in)[inOffset * 8],
|
&reinterpret_cast<const uint32_t*>(in)[inOffset * 8],
|
||||||
valid);
|
valid);
|
||||||
ld256_or_zero_cg_u32<Type>(
|
ld256_cg_or_zero(reinterpret_cast<u32x8_t&>(in_vec2),
|
||||||
in_vec2, &reinterpret_cast<const uint32_t*>(in)[inOffset2 * 8],
|
&reinterpret_cast<const uint32_t*>(in)[inOffset2 * 8],
|
||||||
valid);
|
valid);
|
||||||
} else {
|
} else {
|
||||||
ld128_or_zero_cg_u32<Type>(
|
ld128_cg_or_zero(reinterpret_cast<uint4&>(in_vec),
|
||||||
in_vec, &reinterpret_cast<const uint32_t*>(in)[inOffset * 4],
|
&reinterpret_cast<const uint32_t*>(in)[inOffset * 4],
|
||||||
valid);
|
valid);
|
||||||
ld128_or_zero_cg_u32<Type>(
|
ld128_cg_or_zero(reinterpret_cast<uint4&>(in_vec2),
|
||||||
in_vec2, &reinterpret_cast<const uint32_t*>(in)[inOffset2 * 4],
|
&reinterpret_cast<const uint32_t*>(in)[inOffset2 * 4],
|
||||||
valid);
|
valid);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Compute silu and mul
|
// Compute silu and mul
|
||||||
@@ -142,9 +142,9 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d]
|
|||||||
int const numBlocksPerSM =
|
int const numBlocksPerSM =
|
||||||
vllm_runtime_blocks_per_sm(static_cast<int>(block.x));
|
vllm_runtime_blocks_per_sm(static_cast<int>(block.x));
|
||||||
|
|
||||||
int sf_n_unpadded = int(n / CVT_FP4_ELTS_PER_THREAD);
|
int num_packed_cols = int(n / CVT_FP4_ELTS_PER_THREAD);
|
||||||
|
|
||||||
int grid_y = vllm::div_round_up(sf_n_unpadded, static_cast<int>(block.x));
|
int grid_y = vllm::div_round_up(num_packed_cols, static_cast<int>(block.x));
|
||||||
int grid_x = std::min(
|
int grid_x = std::min(
|
||||||
int(m), std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y));
|
int(m), std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y));
|
||||||
dim3 grid(grid_x, grid_y);
|
dim3 grid(grid_x, grid_y);
|
||||||
@@ -154,7 +154,7 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d]
|
|||||||
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
|
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
|
||||||
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
|
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
|
||||||
vllm::silu_mul_cvt_fp16_to_fp4<cuda_type><<<grid, block, 0, stream>>>(
|
vllm::silu_mul_cvt_fp16_to_fp4<cuda_type><<<grid, block, 0, stream>>>(
|
||||||
m, n, sf_n_unpadded, input_ptr, input_sf_ptr,
|
m, n, num_packed_cols, input_ptr, input_sf_ptr,
|
||||||
reinterpret_cast<uint32_t*>(output_ptr),
|
reinterpret_cast<uint32_t*>(output_ptr),
|
||||||
reinterpret_cast<uint32_t*>(sf_out));
|
reinterpret_cast<uint32_t*>(sf_out));
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
|
|||||||
uint32_t* input_offset_by_experts,
|
uint32_t* input_offset_by_experts,
|
||||||
uint32_t* output_scale_offset_by_experts, int n_experts,
|
uint32_t* output_scale_offset_by_experts, int n_experts,
|
||||||
bool low_latency) {
|
bool low_latency) {
|
||||||
using PackedVec = PackedVec<Type>;
|
using PackedVec = PackedVec<Type, CVT_FP4_PACK16>;
|
||||||
static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
|
static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
|
||||||
(CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
|
(CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
|
||||||
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
|
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
|
||||||
@@ -155,7 +155,7 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
|
|||||||
float const* SFScale, uint32_t* out, uint32_t* SFout,
|
float const* SFScale, uint32_t* out, uint32_t* SFout,
|
||||||
uint32_t* input_offset_by_experts,
|
uint32_t* input_offset_by_experts,
|
||||||
uint32_t* output_scale_offset_by_experts, int n_experts) {
|
uint32_t* output_scale_offset_by_experts, int n_experts) {
|
||||||
using PackedVec = PackedVec<Type>;
|
using PackedVec = PackedVec<Type, CVT_FP4_PACK16>;
|
||||||
static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
|
static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
|
||||||
(CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
|
(CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
|
||||||
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
|
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
|
|||||||
Type const* __restrict__ in,
|
Type const* __restrict__ in,
|
||||||
float const* __restrict__ SFScale,
|
float const* __restrict__ SFScale,
|
||||||
uint32_t* __restrict__ out, uint32_t* __restrict__ SFout) {
|
uint32_t* __restrict__ out, uint32_t* __restrict__ SFout) {
|
||||||
using PackedVec = vllm::PackedVec<Type>;
|
using PackedVec = vllm::PackedVec<Type, CVT_FP4_PACK16>;
|
||||||
|
|
||||||
static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
|
static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
|
||||||
(CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
|
(CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
|
||||||
@@ -71,13 +71,13 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
|
|||||||
// If we are outside valid rows OR outside valid columns -> Use Zeros
|
// If we are outside valid rows OR outside valid columns -> Use Zeros
|
||||||
bool valid = (rowIdx < numRows) && (elem_idx < numCols);
|
bool valid = (rowIdx < numRows) && (elem_idx < numCols);
|
||||||
if constexpr (CVT_FP4_PACK16) {
|
if constexpr (CVT_FP4_PACK16) {
|
||||||
ld256_or_zero_cg_u32<Type>(
|
ld256_cg_or_zero(reinterpret_cast<u32x8_t&>(in_vec),
|
||||||
in_vec, &reinterpret_cast<const uint32_t*>(in)[inOffset * 8],
|
&reinterpret_cast<const uint32_t*>(in)[inOffset * 8],
|
||||||
valid);
|
valid);
|
||||||
} else {
|
} else {
|
||||||
ld128_or_zero_cg_u32<Type>(
|
ld128_cg_or_zero(reinterpret_cast<uint4&>(in_vec),
|
||||||
in_vec, &reinterpret_cast<const uint32_t*>(in)[inOffset * 4],
|
&reinterpret_cast<const uint32_t*>(in)[inOffset * 4],
|
||||||
valid);
|
valid);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto sf_out =
|
auto sf_out =
|
||||||
@@ -114,7 +114,7 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
|
|||||||
float const* __restrict__ SFScale,
|
float const* __restrict__ SFScale,
|
||||||
uint32_t* __restrict__ out,
|
uint32_t* __restrict__ out,
|
||||||
uint32_t* __restrict__ SFout) {
|
uint32_t* __restrict__ SFout) {
|
||||||
using PackedVec = PackedVec<Type>;
|
using PackedVec = PackedVec<Type, CVT_FP4_PACK16>;
|
||||||
|
|
||||||
static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
|
static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
|
||||||
(CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
|
(CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
|
||||||
@@ -139,13 +139,13 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
|
|||||||
// If we are outside valid rows OR outside valid columns -> Use Zeros
|
// If we are outside valid rows OR outside valid columns -> Use Zeros
|
||||||
bool valid = (rowIdx < numRows) && (elem_idx < numCols);
|
bool valid = (rowIdx < numRows) && (elem_idx < numCols);
|
||||||
if constexpr (CVT_FP4_PACK16) {
|
if constexpr (CVT_FP4_PACK16) {
|
||||||
ld256_or_zero_cg_u32<Type>(
|
ld256_cg_or_zero(reinterpret_cast<u32x8_t&>(in_vec),
|
||||||
in_vec, &reinterpret_cast<const uint32_t*>(in)[inOffset * 8],
|
&reinterpret_cast<const uint32_t*>(in)[inOffset * 8],
|
||||||
valid);
|
valid);
|
||||||
} else {
|
} else {
|
||||||
ld128_or_zero_cg_u32<Type>(
|
ld128_cg_or_zero(reinterpret_cast<uint4&>(in_vec),
|
||||||
in_vec, &reinterpret_cast<const uint32_t*>(in)[inOffset * 4],
|
&reinterpret_cast<const uint32_t*>(in)[inOffset * 4],
|
||||||
valid);
|
valid);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto sf_out =
|
auto sf_out =
|
||||||
|
|||||||
@@ -19,8 +19,10 @@
|
|||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
#include <cuda_fp8.h>
|
#include <cuda_fp8.h>
|
||||||
|
|
||||||
#if (defined(NVFP4_ENABLE_ELTS16) && (CUDART_VERSION >= 12090) && \
|
#include "../../cuda_vec_utils.cuh"
|
||||||
defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100)
|
|
||||||
|
#if defined(NVFP4_ENABLE_ELTS16) && defined(CUDA_VERSION) && \
|
||||||
|
CUDA_VERSION >= 12090
|
||||||
#define ELTS_PER_THREAD 16
|
#define ELTS_PER_THREAD 16
|
||||||
constexpr int CVT_FP4_ELTS_PER_THREAD = 16;
|
constexpr int CVT_FP4_ELTS_PER_THREAD = 16;
|
||||||
constexpr bool CVT_FP4_PACK16 = true;
|
constexpr bool CVT_FP4_PACK16 = true;
|
||||||
@@ -34,68 +36,6 @@ constexpr int CVT_FP4_SF_VEC_SIZE = 16;
|
|||||||
|
|
||||||
namespace vllm {
|
namespace vllm {
|
||||||
|
|
||||||
// Convert PyTorch cpp type to CUDA type
|
|
||||||
template <typename T>
|
|
||||||
struct CUDATypeConverter {
|
|
||||||
using Type = T;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct CUDATypeConverter<at::Half> {
|
|
||||||
using Type = half;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct CUDATypeConverter<at::BFloat16> {
|
|
||||||
using Type = __nv_bfloat16;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Get type2 from type or vice versa (applied to half and bfloat16)
|
|
||||||
template <typename T>
|
|
||||||
struct TypeConverter {
|
|
||||||
using Type = half2;
|
|
||||||
}; // keep for generality
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct TypeConverter<half2> {
|
|
||||||
using Type = half;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct TypeConverter<half> {
|
|
||||||
using Type = half2;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct TypeConverter<__nv_bfloat162> {
|
|
||||||
using Type = __nv_bfloat16;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct TypeConverter<__nv_bfloat16> {
|
|
||||||
using Type = __nv_bfloat162;
|
|
||||||
};
|
|
||||||
|
|
||||||
#if (defined(NVFP4_ENABLE_ELTS16) && (CUDART_VERSION >= 12090) && \
|
|
||||||
defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100)
|
|
||||||
// Define a 32 bytes packed data type.
|
|
||||||
template <class Type>
|
|
||||||
struct alignas(32) PackedVec {
|
|
||||||
typename TypeConverter<Type>::Type elts[8];
|
|
||||||
};
|
|
||||||
#else
|
|
||||||
// Define a 16 bytes packed data type.
|
|
||||||
template <class Type>
|
|
||||||
struct alignas(16) PackedVec {
|
|
||||||
typename TypeConverter<Type>::Type elts[4];
|
|
||||||
};
|
|
||||||
#endif
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct PackedVec<__nv_fp8_e4m3> {
|
|
||||||
__nv_fp8x2_e4m3 elts[8];
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename Int>
|
template <typename Int>
|
||||||
__host__ __device__ inline Int round_up(Int x, Int y) {
|
__host__ __device__ inline Int round_up(Int x, Int y) {
|
||||||
static_assert(std::is_integral_v<Int>,
|
static_assert(std::is_integral_v<Int>,
|
||||||
@@ -208,56 +148,6 @@ __device__ __forceinline__ float reciprocal_approximate_ftz(float a) {
|
|||||||
return b;
|
return b;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class Type>
|
|
||||||
__device__ __forceinline__ void ld128_or_zero_cg_u32(PackedVec<Type>& out,
|
|
||||||
const void* ptr,
|
|
||||||
bool pred) {
|
|
||||||
uint32_t r0, r1, r2, r3;
|
|
||||||
|
|
||||||
asm volatile(
|
|
||||||
"{\n"
|
|
||||||
" .reg .pred pr;\n"
|
|
||||||
" setp.ne.u32 pr, %4, 0;\n"
|
|
||||||
" mov.u32 %0, 0;\n"
|
|
||||||
" mov.u32 %1, 0;\n"
|
|
||||||
" mov.u32 %2, 0;\n"
|
|
||||||
" mov.u32 %3, 0;\n"
|
|
||||||
" @pr ld.global.cg.v4.u32 {%0,%1,%2,%3}, [%5];\n"
|
|
||||||
"}\n"
|
|
||||||
: "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3)
|
|
||||||
: "r"((int)pred), "l"(ptr));
|
|
||||||
|
|
||||||
*reinterpret_cast<uint4*>(&out) = uint4{r0, r1, r2, r3};
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class Type>
|
|
||||||
__device__ __forceinline__ void ld256_or_zero_cg_u32(PackedVec<Type>& out,
|
|
||||||
const void* ptr,
|
|
||||||
bool pred) {
|
|
||||||
uint32_t r0, r1, r2, r3, r4, r5, r6, r7;
|
|
||||||
|
|
||||||
asm volatile(
|
|
||||||
"{\n"
|
|
||||||
" .reg .pred pr;\n"
|
|
||||||
" setp.ne.u32 pr, %8, 0;\n"
|
|
||||||
" mov.u32 %0, 0;\n"
|
|
||||||
" mov.u32 %1, 0;\n"
|
|
||||||
" mov.u32 %2, 0;\n"
|
|
||||||
" mov.u32 %3, 0;\n"
|
|
||||||
" mov.u32 %4, 0;\n"
|
|
||||||
" mov.u32 %5, 0;\n"
|
|
||||||
" mov.u32 %6, 0;\n"
|
|
||||||
" mov.u32 %7, 0;\n"
|
|
||||||
" @pr ld.global.cg.v8.u32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%9];\n"
|
|
||||||
"}\n"
|
|
||||||
: "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4), "=r"(r5), "=r"(r6),
|
|
||||||
"=r"(r7)
|
|
||||||
: "r"((int)pred), "l"(ptr));
|
|
||||||
|
|
||||||
reinterpret_cast<uint4*>(&out)[0] = uint4{r0, r1, r2, r3};
|
|
||||||
reinterpret_cast<uint4*>(&out)[1] = uint4{r4, r5, r6, r7};
|
|
||||||
}
|
|
||||||
|
|
||||||
// Compute SF output offset for swizzled tensor core layout.
|
// Compute SF output offset for swizzled tensor core layout.
|
||||||
// SF layout: [numMTiles, numKTiles, 32, 4, 4]
|
// SF layout: [numMTiles, numKTiles, 32, 4, 4]
|
||||||
// Caller must precompute: numKTiles = (numCols + 63) / 64
|
// Caller must precompute: numKTiles = (numCols + 63) / 64
|
||||||
@@ -315,8 +205,8 @@ __device__ __forceinline__ uint8_t* sf_out_rowmajor_u8(int row, int pack,
|
|||||||
|
|
||||||
// Quantizes the provided PackedVec into the uint32_t output
|
// Quantizes the provided PackedVec into the uint32_t output
|
||||||
template <class Type, int CVT_FP4_NUM_THREADS_PER_SF, bool UE8M0_SF = false>
|
template <class Type, int CVT_FP4_NUM_THREADS_PER_SF, bool UE8M0_SF = false>
|
||||||
__device__ __forceinline__ fp4_packed_t
|
__device__ __forceinline__ fp4_packed_t cvt_warp_fp16_to_fp4(
|
||||||
cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal, uint8_t* SFout) {
|
PackedVec<Type, CVT_FP4_PACK16>& vec, float SFScaleVal, uint8_t* SFout) {
|
||||||
// Get absolute maximum values among the local 8 values.
|
// Get absolute maximum values among the local 8 values.
|
||||||
auto localMax = __habs2(vec.elts[0]);
|
auto localMax = __habs2(vec.elts[0]);
|
||||||
|
|
||||||
@@ -372,11 +262,7 @@ cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal, uint8_t* SFout) {
|
|||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
|
for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
|
||||||
if constexpr (std::is_same_v<Type, half>) {
|
fp2Vals[i] = cast_to_float2(vec.elts[i]);
|
||||||
fp2Vals[i] = __half22float2(vec.elts[i]);
|
|
||||||
} else {
|
|
||||||
fp2Vals[i] = __bfloat1622float2(vec.elts[i]);
|
|
||||||
}
|
|
||||||
fp2Vals[i].x *= outputScale;
|
fp2Vals[i].x *= outputScale;
|
||||||
fp2Vals[i].y *= outputScale;
|
fp2Vals[i].y *= outputScale;
|
||||||
}
|
}
|
||||||
@@ -395,22 +281,19 @@ __device__ __forceinline__ float2 silu2(float2 x) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <class Type>
|
template <class Type>
|
||||||
__inline__ __device__ PackedVec<Type> compute_silu_mul(
|
__inline__ __device__ PackedVec<Type, CVT_FP4_PACK16> compute_silu_mul(
|
||||||
const PackedVec<Type>& x_vec, const PackedVec<Type>& y_vec) {
|
const PackedVec<Type, CVT_FP4_PACK16>& x_vec,
|
||||||
PackedVec<Type> result;
|
const PackedVec<Type, CVT_FP4_PACK16>& y_vec) {
|
||||||
|
PackedVec<Type, CVT_FP4_PACK16> result;
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; ++i) {
|
for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; ++i) {
|
||||||
// silu_mul in float32
|
// silu_mul in float32
|
||||||
if constexpr (std::is_same_v<Type, half>) {
|
using packed_t = typename PackedTypeConverter<Type>::Type;
|
||||||
float2 silu_vec = silu2(__half22float2(x_vec.elts[i]));
|
float2 silu_vec = silu2(cast_to_float2(x_vec.elts[i]));
|
||||||
result.elts[i] = __float22half2_rn(
|
float2 y_f2 = cast_to_float2(y_vec.elts[i]);
|
||||||
__fmul2_rn(silu_vec, __half22float2(y_vec.elts[i])));
|
result.elts[i] = cast_to_packed<packed_t>(
|
||||||
} else {
|
make_float2(silu_vec.x * y_f2.x, silu_vec.y * y_f2.y));
|
||||||
float2 silu_vec = silu2(__bfloat1622float2(x_vec.elts[i]));
|
|
||||||
result.elts[i] = __float22bfloat162_rn(
|
|
||||||
__fmul2_rn(silu_vec, __bfloat1622float2(y_vec.elts[i])));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -263,12 +263,10 @@ void get_cutlass_moe_mm_data_caller(
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <bool SWAP_AB>
|
template <bool SWAP_AB>
|
||||||
__global__ void compute_pplx_data(int32_t* expert_offsets,
|
__global__ void compute_batched_moe_data(
|
||||||
int32_t* problem_sizes1,
|
int32_t* expert_offsets, int32_t* problem_sizes1, int32_t* problem_sizes2,
|
||||||
int32_t* problem_sizes2,
|
const int32_t* __restrict__ expert_num_tokens, const int padded_m,
|
||||||
const int32_t* __restrict__ expert_num_tokens,
|
const int n, const int k) {
|
||||||
const int padded_m, const int n,
|
|
||||||
const int k) {
|
|
||||||
int expert_idx = threadIdx.x;
|
int expert_idx = threadIdx.x;
|
||||||
expert_offsets[expert_idx] = expert_idx * padded_m;
|
expert_offsets[expert_idx] = expert_idx * padded_m;
|
||||||
|
|
||||||
@@ -289,24 +287,22 @@ __global__ void compute_pplx_data(int32_t* expert_offsets,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
|
void get_cutlass_batched_moe_mm_data_caller(
|
||||||
torch::Tensor& problem_sizes1,
|
torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1,
|
||||||
torch::Tensor& problem_sizes2,
|
torch::Tensor& problem_sizes2, const torch::Tensor& expert_num_tokens,
|
||||||
const torch::Tensor& expert_num_tokens,
|
const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
|
||||||
const int64_t num_local_experts,
|
const int64_t k) {
|
||||||
const int64_t padded_m,
|
|
||||||
const int64_t n, const int64_t k) {
|
|
||||||
auto stream = at::cuda::getCurrentCUDAStream(expert_offsets.device().index());
|
auto stream = at::cuda::getCurrentCUDAStream(expert_offsets.device().index());
|
||||||
|
|
||||||
if (num_local_experts * padded_m > SWAP_AB_THRESHOLD) {
|
if (num_local_experts * padded_m > SWAP_AB_THRESHOLD) {
|
||||||
compute_pplx_data<false><<<1, num_local_experts, 0, stream>>>(
|
compute_batched_moe_data<false><<<1, num_local_experts, 0, stream>>>(
|
||||||
static_cast<int32_t*>(expert_offsets.data_ptr()),
|
static_cast<int32_t*>(expert_offsets.data_ptr()),
|
||||||
static_cast<int32_t*>(problem_sizes1.data_ptr()),
|
static_cast<int32_t*>(problem_sizes1.data_ptr()),
|
||||||
static_cast<int32_t*>(problem_sizes2.data_ptr()),
|
static_cast<int32_t*>(problem_sizes2.data_ptr()),
|
||||||
static_cast<const int32_t*>(expert_num_tokens.data_ptr()), padded_m, n,
|
static_cast<const int32_t*>(expert_num_tokens.data_ptr()), padded_m, n,
|
||||||
k);
|
k);
|
||||||
} else {
|
} else {
|
||||||
compute_pplx_data<true><<<1, num_local_experts, 0, stream>>>(
|
compute_batched_moe_data<true><<<1, num_local_experts, 0, stream>>>(
|
||||||
static_cast<int32_t*>(expert_offsets.data_ptr()),
|
static_cast<int32_t*>(expert_offsets.data_ptr()),
|
||||||
static_cast<int32_t*>(problem_sizes1.data_ptr()),
|
static_cast<int32_t*>(problem_sizes1.data_ptr()),
|
||||||
static_cast<int32_t*>(problem_sizes2.data_ptr()),
|
static_cast<int32_t*>(problem_sizes2.data_ptr()),
|
||||||
|
|||||||
@@ -82,13 +82,11 @@ void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
|
|||||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||||
const int64_t n, const int64_t k, const bool swap_ab);
|
const int64_t n, const int64_t k, const bool swap_ab);
|
||||||
|
|
||||||
void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
|
void get_cutlass_batched_moe_mm_data_caller(
|
||||||
torch::Tensor& problem_sizes1,
|
torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1,
|
||||||
torch::Tensor& problem_sizes2,
|
torch::Tensor& problem_sizes2, const torch::Tensor& expert_num_tokens,
|
||||||
const torch::Tensor& expert_num_tokens,
|
const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
|
||||||
const int64_t num_local_experts,
|
const int64_t k);
|
||||||
const int64_t padded_m,
|
|
||||||
const int64_t n, const int64_t k);
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
|
void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
|
||||||
@@ -319,29 +317,30 @@ void get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
|
|||||||
version_num, ". Required capability: 90, 100, or 120");
|
version_num, ". Required capability: 90, 100, or 120");
|
||||||
}
|
}
|
||||||
|
|
||||||
void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
|
void get_cutlass_batched_moe_mm_data(torch::Tensor& expert_offsets,
|
||||||
torch::Tensor& problem_sizes1,
|
torch::Tensor& problem_sizes1,
|
||||||
torch::Tensor& problem_sizes2,
|
torch::Tensor& problem_sizes2,
|
||||||
const torch::Tensor& expert_num_tokens,
|
const torch::Tensor& expert_num_tokens,
|
||||||
const int64_t num_local_experts,
|
const int64_t num_local_experts,
|
||||||
const int64_t padded_m, const int64_t n,
|
const int64_t padded_m, const int64_t n,
|
||||||
const int64_t k) {
|
const int64_t k) {
|
||||||
// This function currently gets compiled only if we have a valid cutlass moe
|
// This function currently gets compiled only if we have a valid cutlass moe
|
||||||
// mm to run it for.
|
// mm to run it for.
|
||||||
int32_t version_num = get_sm_version_num();
|
int32_t version_num = get_sm_version_num();
|
||||||
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
|
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
|
||||||
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
|
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
|
||||||
(defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
|
(defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
|
||||||
get_cutlass_pplx_moe_mm_data_caller(expert_offsets, problem_sizes1,
|
get_cutlass_batched_moe_mm_data_caller(expert_offsets, problem_sizes1,
|
||||||
problem_sizes2, expert_num_tokens,
|
problem_sizes2, expert_num_tokens,
|
||||||
num_local_experts, padded_m, n, k);
|
num_local_experts, padded_m, n, k);
|
||||||
return;
|
return;
|
||||||
#endif
|
#endif
|
||||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
TORCH_CHECK_NOT_IMPLEMENTED(false,
|
||||||
false,
|
"No compiled get_cutlass_batched_moe_mm_data: no "
|
||||||
"No compiled get_cutlass_pplx_moe_mm_data: no cutlass_scaled_mm kernel "
|
"cutlass_scaled_mm kernel "
|
||||||
"for CUDA device capability: ",
|
"for CUDA device capability: ",
|
||||||
version_num, ". Required capability: 90, 100, or 120");
|
version_num,
|
||||||
|
". Required capability: 90, 100, or 120");
|
||||||
}
|
}
|
||||||
|
|
||||||
void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
|
void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
|
||||||
|
|||||||
@@ -304,8 +304,9 @@ __device__ inline unsigned int min__(uint32_t a, uint32_t b) {
|
|||||||
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
|
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
|
||||||
int UNRL, int N>
|
int UNRL, int N>
|
||||||
__global__ void __launch_bounds__(WvPrGrp* THRDS)
|
__global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||||
wvSplitK_hf_sml_(const int K, const int M, const int Bx, const int By,
|
wvSplitK_hf_sml_(const int K, const int Kbp, const int Kap, const int M,
|
||||||
const scalar_t* B, const scalar_t* __restrict__ A,
|
const int Bx, const int By, const scalar_t* B,
|
||||||
|
const scalar_t* __restrict__ A,
|
||||||
const scalar_t* __restrict__ BIAS, scalar_t* C,
|
const scalar_t* __restrict__ BIAS, scalar_t* C,
|
||||||
const int _WvPrGrp, const int CuCount) {
|
const int _WvPrGrp, const int CuCount) {
|
||||||
constexpr int max_lds_len = LDS_SIZE / 2;
|
constexpr int max_lds_len = LDS_SIZE / 2;
|
||||||
@@ -314,7 +315,6 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|||||||
#else
|
#else
|
||||||
constexpr bool use_mfma = false;
|
constexpr bool use_mfma = false;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
using scalar8 =
|
using scalar8 =
|
||||||
__attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float;
|
__attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float;
|
||||||
using half4 =
|
using half4 =
|
||||||
@@ -346,13 +346,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|||||||
// - Then the WG will move to another 8 K elements
|
// - Then the WG will move to another 8 K elements
|
||||||
// TODO: Logic below will only work when K is multiple of 8
|
// TODO: Logic below will only work when K is multiple of 8
|
||||||
//----------------------------------------------------
|
//----------------------------------------------------
|
||||||
for (uint32_t k = 0; k < min__(K * N, max_lds_len);
|
for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK;
|
||||||
k += THRDS * WvPrGrp * A_CHUNK) {
|
k < min__(Kap * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
|
||||||
uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);
|
#if defined(__gfx950__)
|
||||||
|
__builtin_amdgcn_global_load_lds((int*)(&A[k]), (int*)(&s[k]), 16, 0, 0);
|
||||||
if (k_in >= min__(K * N, max_lds_len)) break;
|
#else
|
||||||
|
*((bigType*)(&s[k])) = *((bigType*)(&A[k]));
|
||||||
*((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in]));
|
#endif
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
@@ -360,9 +360,6 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|||||||
|
|
||||||
uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE;
|
uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE;
|
||||||
|
|
||||||
float sum[N][YTILE];
|
|
||||||
scalar8 sum4[N][YTILE];
|
|
||||||
|
|
||||||
//----------------------------------------------------
|
//----------------------------------------------------
|
||||||
// Each wave works on a single column of weight matrix.
|
// Each wave works on a single column of weight matrix.
|
||||||
// There are 16 waves per WG, and hence, each WG is
|
// There are 16 waves per WG, and hence, each WG is
|
||||||
@@ -386,44 +383,20 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|||||||
// YTILE represents how many column of weight matrix
|
// YTILE represents how many column of weight matrix
|
||||||
// are being worked on by each wave.
|
// are being worked on by each wave.
|
||||||
//----------------------------------------------------
|
//----------------------------------------------------
|
||||||
for (int i = 0; i < YTILE; i++)
|
float sum[N][YTILE] = {};
|
||||||
for (int n = 0; n < N; n++)
|
scalar8 sum4[N][YTILE] = {};
|
||||||
if constexpr (!use_mfma)
|
|
||||||
sum[n][i] = 0;
|
|
||||||
else
|
|
||||||
sum4[n][i] = {0, 0, 0, 0};
|
|
||||||
|
|
||||||
bigType bigA[N][UNRL];
|
|
||||||
bigType bigB[YTILE][UNRL];
|
|
||||||
//----------------------------------------------------
|
|
||||||
// Fetch weight matrix B in interleaved K-split!
|
|
||||||
// - Each thread (lane) is fetching 8 elements (A_Chunk)
|
|
||||||
// - Each wave will fetch 64*8=> 512 elements (1024B)
|
|
||||||
// - YTILE represents the number of column being serviced
|
|
||||||
// by wave
|
|
||||||
// - Loop for fetching weight matrix (B) are unrolled
|
|
||||||
//
|
|
||||||
// Fetch activation matrix A from LDS
|
|
||||||
// - Loop for fetching activation matrix (A) are unrolled
|
|
||||||
//
|
|
||||||
// Finally, do the matrix multiplication in an unrolled
|
|
||||||
// fashion. This provides lot of food for compiler
|
|
||||||
// scheduling.
|
|
||||||
//
|
|
||||||
// TODO: Logic below will only work when K is multiple of 8
|
|
||||||
//----------------------------------------------------
|
|
||||||
// for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
|
|
||||||
for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
|
for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
|
||||||
|
bigType bigA[N][UNRL] = {};
|
||||||
|
bigType bigB[YTILE][UNRL];
|
||||||
// Fetch the weight matrix from memory!
|
// Fetch the weight matrix from memory!
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
|
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
|
||||||
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
|
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
|
||||||
uint32_t k_ = k + threadIdx.x * A_CHUNK;
|
uint32_t k_ = k + threadIdx.x * A_CHUNK;
|
||||||
if (k_ >= K) break;
|
const scalar_t* B_ = &B[min__(k_, K - A_CHUNK)];
|
||||||
|
|
||||||
const scalar_t* B_ = &B[(m + 0) * K + k_];
|
|
||||||
for (int y = 0; y < YTILE; y++)
|
for (int y = 0; y < YTILE; y++)
|
||||||
bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[y * K])));
|
bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[min__(y + m, M - 1) * Kbp])));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fetch activation matrix from either just LDS or from both LDS / memory
|
// Fetch activation matrix from either just LDS or from both LDS / memory
|
||||||
@@ -432,33 +405,20 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|||||||
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
|
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
|
||||||
uint32_t k_ = k + threadIdx.x * A_CHUNK;
|
uint32_t k_ = k + threadIdx.x * A_CHUNK;
|
||||||
if (k_ >= K) break;
|
if (k_ >= K) break;
|
||||||
|
|
||||||
// Fetch A activation matrix in interleaved fashion from LDS or memory
|
|
||||||
|
|
||||||
for (int n = 0; n < N; n++) {
|
for (int n = 0; n < N; n++) {
|
||||||
bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n])));
|
bigA[n][k2] = *((const bigType*)(&(s[k_ + Kap * n])));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Do the matrix multiplication in interleaved manner
|
// Do the matrix multiplication in interleaved manner
|
||||||
#pragma unroll
|
|
||||||
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
|
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
|
||||||
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
|
|
||||||
uint32_t k_ = k + threadIdx.x * A_CHUNK;
|
|
||||||
if (k_ >= K) break;
|
|
||||||
// Do the matrix multiplication of activation and weight matrix
|
|
||||||
// - Remember the accumulation is happening for K-split of 64!
|
|
||||||
#pragma unroll
|
|
||||||
for (uint32_t n = 0; n < N; n++) {
|
for (uint32_t n = 0; n < N; n++) {
|
||||||
#pragma unroll
|
|
||||||
for (int y = 0; y < YTILE; y++) {
|
for (int y = 0; y < YTILE; y++) {
|
||||||
if constexpr (!use_mfma)
|
if constexpr (!use_mfma)
|
||||||
#pragma unroll
|
|
||||||
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
|
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
|
||||||
DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b])
|
DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b])
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
#pragma unroll
|
|
||||||
for (uint32_t b = 0; b < A_CHUNK / 4; b++)
|
for (uint32_t b = 0; b < A_CHUNK / 4; b++)
|
||||||
sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(
|
sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(
|
||||||
bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0);
|
bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0);
|
||||||
@@ -466,46 +426,44 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
__builtin_amdgcn_sched_barrier(0);
|
||||||
//----------------------------------------------------
|
//----------------------------------------------------
|
||||||
// Final reduction step using shuffle
|
// Final reduction step using shuffle
|
||||||
//----------------------------------------------------
|
//----------------------------------------------------
|
||||||
if constexpr (!use_mfma) {
|
if constexpr (!use_mfma) {
|
||||||
for (int n = 0; n < N; n++) {
|
for (int n = 0; n < N; n++) {
|
||||||
for (int y = 0; y < YTILE; y++) {
|
for (int y = 0; y < YTILE; y++) {
|
||||||
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
|
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x118, 0xf, 0xf,
|
||||||
: "=v"(sum[n][y])
|
1); // row_shr8
|
||||||
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
|
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x114, 0xf, 0xf,
|
||||||
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
|
1); // row_shr4
|
||||||
: "=v"(sum[n][y])
|
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x112, 0xf, 0xf,
|
||||||
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
|
1); // row_shr2
|
||||||
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
|
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x111, 0xf, 0xf,
|
||||||
: "=v"(sum[n][y])
|
1); // row_shr1
|
||||||
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
|
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x142, 0xf, 0xf,
|
||||||
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
|
1); // ROW_BCAST15
|
||||||
: "=v"(sum[n][y])
|
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x143, 0xf, 0xf,
|
||||||
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
|
1); // ROW_BCAST31
|
||||||
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
|
|
||||||
: "=v"(sum[n][y])
|
|
||||||
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
|
|
||||||
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
|
|
||||||
: "=v"(sum[n][y])
|
|
||||||
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (threadIdx.x == 63) {
|
if (threadIdx.x == 63) {
|
||||||
for (int n = 0; n < N; n++) {
|
scalar_t biases[N][YTILE] = {};
|
||||||
for (int i = 0; i < YTILE; i++) {
|
if (BIAS)
|
||||||
if constexpr (std::is_same_v<scalar_t, half>) {
|
for (int n = 0; n < N; n++) {
|
||||||
if (BIAS)
|
for (int y = 0; y < YTILE; y++) {
|
||||||
sum[n][i] += __half2float(BIAS[(m + i) % Bx + (n % By) * M]);
|
biases[n][y] = BIAS[(m + y) % Bx + (n % By) * Bx];
|
||||||
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
|
|
||||||
if (BIAS)
|
|
||||||
sum[n][i] +=
|
|
||||||
__bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
|
|
||||||
}
|
}
|
||||||
C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
|
}
|
||||||
|
for (int n = 0; n < N; n++) {
|
||||||
|
for (int y = 0; y < YTILE; y++) {
|
||||||
|
if constexpr (std::is_same_v<scalar_t, half>) {
|
||||||
|
sum[n][y] += __half2float(biases[n][y]);
|
||||||
|
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
|
||||||
|
sum[n][y] += __bfloat162float(biases[n][y]);
|
||||||
|
}
|
||||||
|
C[m + y + n * M] = __float2s<scalar_t>(sum[n][y]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -514,45 +472,43 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|||||||
for (int n = 0; n < N; n++) {
|
for (int n = 0; n < N; n++) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int y = 0; y < YTILE; y++) {
|
for (int y = 0; y < YTILE; y++) {
|
||||||
// float accm1 = 0;
|
/*float accm1 = 0;
|
||||||
// for (int i=0; i<64; i++)
|
for (int i=0; i<64; i++)
|
||||||
// accm1 += __shfl(sum4[n][y][i%4], i);
|
accm1 += __shfl(sum4[n][y][i%4], i);
|
||||||
|
sum4[n][y][0] = accm1;*/
|
||||||
float accm = sum4[n][y][0];
|
float accm = sum4[n][y][0];
|
||||||
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
|
accm += __builtin_amdgcn_mov_dpp(sum4[n][y][1], 0x101, 0xf, 0xf,
|
||||||
: "=v"(accm)
|
1); // row_shl1
|
||||||
: "0"(accm), "v"(sum4[n][y][1]), "v"(accm));
|
accm += __builtin_amdgcn_mov_dpp(sum4[n][y][2], 0x102, 0xf, 0xf,
|
||||||
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
|
1); // row_shl2
|
||||||
: "=v"(accm)
|
accm += __builtin_amdgcn_mov_dpp(sum4[n][y][3], 0x103, 0xf, 0xf,
|
||||||
: "0"(accm), "v"(sum4[n][y][2]), "v"(accm));
|
1); // row_shl3
|
||||||
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
|
accm += __builtin_amdgcn_mov_dpp(accm, 0x104, 0xf, 0xf,
|
||||||
: "=v"(accm)
|
1); // row_shl4
|
||||||
: "0"(accm), "v"(sum4[n][y][3]), "v"(accm));
|
accm += __builtin_amdgcn_mov_dpp(accm, 0x108, 0xf, 0xf,
|
||||||
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 "
|
1); // row_shl8
|
||||||
: "=v"(accm)
|
accm = __builtin_amdgcn_mov_dpp(accm, 0x11f, 0xf, 0xf,
|
||||||
: "0"(accm), "v"(accm), "v"(accm));
|
1); // row_shr15
|
||||||
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
|
accm += __builtin_amdgcn_mov_dpp(accm, 0x142, 0xf, 0xf,
|
||||||
: "=v"(accm)
|
1); // ROW_BCAST15
|
||||||
: "0"(accm), "v"(accm), "v"(accm));
|
accm += __builtin_amdgcn_mov_dpp(accm, 0x143, 0xf, 0xf,
|
||||||
asm("s_nop 0\n\tv_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 "
|
1); // ROW_BCAST31
|
||||||
: "=v"(accm)
|
|
||||||
: "0"(accm), "v"(accm));
|
|
||||||
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
|
|
||||||
: "=v"(accm)
|
|
||||||
: "0"(accm), "v"(accm), "v"(accm));
|
|
||||||
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
|
|
||||||
: "=v"(accm)
|
|
||||||
: "0"(accm), "v"(accm), "v"(accm));
|
|
||||||
|
|
||||||
sum4[n][y][0] = accm;
|
sum4[n][y][0] = accm;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (threadIdx.x == 63) {
|
if (threadIdx.x == 63) {
|
||||||
|
scalar_t biases[N][YTILE] = {};
|
||||||
|
if (BIAS)
|
||||||
|
for (int n = 0; n < N; n++) {
|
||||||
|
for (int y = 0; y < YTILE; y++) {
|
||||||
|
biases[n][y] = BIAS[(m + y) % Bx + (n % By) * Bx];
|
||||||
|
}
|
||||||
|
}
|
||||||
for (int n = 0; n < N; n++) {
|
for (int n = 0; n < N; n++) {
|
||||||
for (int i = 0; i < YTILE; i++) {
|
for (int y = 0; y < YTILE; y++) {
|
||||||
if (BIAS)
|
sum4[n][y][0] += __bfloat162float(biases[n][y]);
|
||||||
sum4[n][i][0] +=
|
C[m + y + n * M] = __float2bfloat16(sum4[n][y][0]);
|
||||||
__bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
|
|
||||||
C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -563,8 +519,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|||||||
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
|
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
|
||||||
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
|
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
|
||||||
int UNRL, int N>
|
int UNRL, int N>
|
||||||
__global__ void wvSplitK_hf_sml_(const int K, const int M, const int Bx,
|
__global__ void wvSplitK_hf_sml_(const int K, const int Kbp, const int Kap,
|
||||||
const int By, const scalar_t* B,
|
const int M, const int Bx, const int By,
|
||||||
|
const scalar_t* B,
|
||||||
const scalar_t* __restrict__ A,
|
const scalar_t* __restrict__ A,
|
||||||
const scalar_t* __restrict__ BIAS, scalar_t* C,
|
const scalar_t* __restrict__ BIAS, scalar_t* C,
|
||||||
const int _WvPrGrp, const int CuCount) {
|
const int _WvPrGrp, const int CuCount) {
|
||||||
@@ -577,8 +534,9 @@ __global__ void wvSplitK_hf_sml_(const int K, const int M, const int Bx,
|
|||||||
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
|
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
|
||||||
int UNRL, int N>
|
int UNRL, int N>
|
||||||
__global__ void __launch_bounds__(WvPrGrp* THRDS)
|
__global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||||
wvSplitK_hf_(const int K, const int M, const int Bx, const int By,
|
wvSplitK_hf_(const int K, const int Kbp, const int Kap, const int M,
|
||||||
const scalar_t* B, const scalar_t* __restrict__ A,
|
const int Bx, const int By, const scalar_t* B,
|
||||||
|
const scalar_t* __restrict__ A,
|
||||||
const scalar_t* __restrict__ BIAS, scalar_t* C,
|
const scalar_t* __restrict__ BIAS, scalar_t* C,
|
||||||
const int _WvPrGrp, const int CuCount) {
|
const int _WvPrGrp, const int CuCount) {
|
||||||
constexpr int max_lds_len = LDS_SIZE / 2;
|
constexpr int max_lds_len = LDS_SIZE / 2;
|
||||||
@@ -601,13 +559,6 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|||||||
scalar8 h8;
|
scalar8 h8;
|
||||||
};
|
};
|
||||||
|
|
||||||
//----------------------------------------------------
|
|
||||||
// Reserving 64 KB of LDS to have 1 WG / CU
|
|
||||||
// Goal is to bring the activation matrix A to the LDS
|
|
||||||
// and use it across the lifetime of the work group
|
|
||||||
// TODO: When activation matrix is larger than 64 KB
|
|
||||||
// then this is not going to work!
|
|
||||||
//----------------------------------------------------
|
|
||||||
__shared__ scalar_t s[max_lds_len];
|
__shared__ scalar_t s[max_lds_len];
|
||||||
|
|
||||||
//----------------------------------------------------
|
//----------------------------------------------------
|
||||||
@@ -618,12 +569,6 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|||||||
commitColumn[i] = 1;
|
commitColumn[i] = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
//----------------------------------------------------
|
|
||||||
// Indexing function into the column of weight matrix B
|
|
||||||
// Algorithm does 64 lane k-splitting / wave and uses
|
|
||||||
// WG ID and Thread ID to find the index.
|
|
||||||
//----------------------------------------------------
|
|
||||||
// int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp);
|
|
||||||
uint32_t m = (blockIdx.x * _WvPrGrp + threadIdx.y) * YTILE;
|
uint32_t m = (blockIdx.x * _WvPrGrp + threadIdx.y) * YTILE;
|
||||||
|
|
||||||
// Check whether there will be fragmentation!
|
// Check whether there will be fragmentation!
|
||||||
@@ -636,91 +581,34 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|||||||
m = startColumn;
|
m = startColumn;
|
||||||
}
|
}
|
||||||
|
|
||||||
//----------------------------------------------------
|
for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK;
|
||||||
// Fetch the activation matrix to LDS
|
k < min__(Kap * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
|
||||||
// Loop iteration:
|
#if defined(__gfx950__)
|
||||||
// - Each thread (lane) is fetching 8 elements (A_Chunk)
|
__builtin_amdgcn_global_load_lds((int*)(&A[k]), (int*)(&s[k]), 16, 0, 0);
|
||||||
// - Each wave will fetch 64*8=> 512 elements
|
#else
|
||||||
// - Each WG will fetch 512 * 16 => 8K elements
|
*((bigType*)(&s[k])) = *((bigType*)(&A[k]));
|
||||||
// - Then the WG will move to another 8 K elements
|
#endif
|
||||||
// TODO: Logic below will only work when K is multiple of 8
|
|
||||||
//----------------------------------------------------
|
|
||||||
for (uint32_t k = 0; k < min__(K * N, max_lds_len);
|
|
||||||
k += THRDS * WvPrGrp * A_CHUNK) {
|
|
||||||
uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);
|
|
||||||
|
|
||||||
if (k_in >= min__(K * N, max_lds_len)) break;
|
|
||||||
|
|
||||||
*((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in]));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
if (threadIdx.y >= _WvPrGrp) return;
|
if (threadIdx.y >= _WvPrGrp) return;
|
||||||
|
|
||||||
float sum[N][YTILE];
|
|
||||||
scalar8 sum4[N][YTILE];
|
|
||||||
|
|
||||||
//----------------------------------------------------
|
|
||||||
// Each wave works on a single column of weight matrix.
|
|
||||||
// There are 16 waves per WG, and hence, each WG is
|
|
||||||
// working on 16 columns of weight matrix. Moreover,
|
|
||||||
// we tile in column direction by YTILE, so when YTILE=1
|
|
||||||
// the above math is right, however, when YTILE=2 then
|
|
||||||
// each wave will be working on 2 columns and WG will
|
|
||||||
// be working on 32 columns.
|
|
||||||
//
|
|
||||||
// Top level loop that makes WGs persistent!
|
|
||||||
// - WGs iterates across columns of weight matrix
|
|
||||||
// - Each wave within WG works on a given column(s)
|
|
||||||
// - After completing first set of columns, WGs start
|
|
||||||
// working on the next set of available columns
|
|
||||||
//----------------------------------------------------
|
|
||||||
while (m < M) {
|
while (m < M) {
|
||||||
//----------------------------------------------------
|
float sum[N][YTILE] = {};
|
||||||
// 'sum' accumulates the matrix A x B computation
|
scalar8 sum4[N][YTILE] = {};
|
||||||
// split across 64 lanes.
|
|
||||||
//
|
|
||||||
// YTILE represents how many column of weight matrix
|
|
||||||
// are being worked on by each wave.
|
|
||||||
//----------------------------------------------------
|
|
||||||
for (int i = 0; i < YTILE; i++)
|
|
||||||
for (int n = 0; n < N; n++)
|
|
||||||
if constexpr (!use_mfma)
|
|
||||||
sum[n][i] = 0;
|
|
||||||
else
|
|
||||||
sum4[n][i] = {0, 0, 0, 0};
|
|
||||||
|
|
||||||
bigType bigA[N][UNRL];
|
|
||||||
bigType bigB[YTILE][UNRL];
|
|
||||||
//----------------------------------------------------
|
|
||||||
// Fetch weight matrix B in interleaved K-split!
|
|
||||||
// - Each thread (lane) is fetching 8 elements (A_Chunk)
|
|
||||||
// - Each wave will fetch 64*8=> 512 elements (1024B)
|
|
||||||
// - YTILE represents the number of column being serviced
|
|
||||||
// by wave
|
|
||||||
// - Loop for fetching weight matrix (B) are unrolled
|
|
||||||
//
|
|
||||||
// Fetch activation matrix A from LDS
|
|
||||||
// - Loop for fetching activation matrix (A) are unrolled
|
|
||||||
//
|
|
||||||
// Finally, do the matrix multiplication in an unrolled
|
|
||||||
// fashion. This provides lot of food for compiler
|
|
||||||
// scheduling.
|
|
||||||
//
|
|
||||||
// TODO: Logic below will only work when K is multiple of 8
|
|
||||||
//----------------------------------------------------
|
|
||||||
for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
|
for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
|
||||||
|
bigType bigA[N][UNRL] = {};
|
||||||
|
bigType bigB[YTILE][UNRL];
|
||||||
// Fetch the weight matrix from memory!
|
// Fetch the weight matrix from memory!
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
|
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
|
||||||
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
|
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
|
||||||
uint32_t k_ = k + threadIdx.x * A_CHUNK;
|
uint32_t k_ = k + threadIdx.x * A_CHUNK;
|
||||||
if (k_ >= K) break;
|
const scalar_t* B_ = &B[min__(k_, K - A_CHUNK)];
|
||||||
|
for (int y = 0; y < YTILE; y++)
|
||||||
const scalar_t* B_ = &B[(m + 0) * K + k_];
|
bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[min__(y + m, M - 1) * Kbp])));
|
||||||
for (int b = 0; b < YTILE; b++)
|
|
||||||
bigB[b][k2].h8 = (loadnt((scalar8*)(&B_[b * K])));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fetch activation matrix from either just LDS or from both LDS / memory
|
// Fetch activation matrix from either just LDS or from both LDS / memory
|
||||||
@@ -729,36 +617,23 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|||||||
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
|
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
|
||||||
uint32_t k_ = k + threadIdx.x * A_CHUNK;
|
uint32_t k_ = k + threadIdx.x * A_CHUNK;
|
||||||
if (k_ >= K) break;
|
if (k_ >= K) break;
|
||||||
|
|
||||||
// Fetch A activation matrix in interleaved fashion from LDS or memory
|
|
||||||
|
|
||||||
for (int n = 0; n < N; n++) {
|
for (int n = 0; n < N; n++) {
|
||||||
if (k_ + K * n < max_lds_len)
|
if (k_ + Kap * n < max_lds_len)
|
||||||
bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n])));
|
bigA[n][k2] = *((const bigType*)(&(s[k_ + Kap * n])));
|
||||||
else
|
else
|
||||||
bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n])));
|
bigA[n][k2] = *((const bigType*)(&(A[k_ + Kap * n])));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Do the matrix multiplication in interleaved manner
|
// Do the matrix multiplication in interleaved manner
|
||||||
#pragma unroll
|
|
||||||
for (uint32_t n = 0; n < N; n++) {
|
for (uint32_t n = 0; n < N; n++) {
|
||||||
#pragma unroll
|
|
||||||
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
|
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
|
||||||
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
|
|
||||||
uint32_t k_ = k + threadIdx.x * A_CHUNK;
|
|
||||||
if (k_ >= K) break;
|
|
||||||
// Do the matrix multiplication of activation and weight matrix
|
|
||||||
// - Remember the accumulation is happening for K-split of 64!
|
|
||||||
#pragma unroll
|
|
||||||
for (int y = 0; y < YTILE; y++) {
|
for (int y = 0; y < YTILE; y++) {
|
||||||
if constexpr (!use_mfma)
|
if constexpr (!use_mfma)
|
||||||
#pragma unroll
|
|
||||||
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
|
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
|
||||||
DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b])
|
DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b])
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
#pragma unroll
|
|
||||||
for (uint32_t b = 0; b < A_CHUNK / 4; b++)
|
for (uint32_t b = 0; b < A_CHUNK / 4; b++)
|
||||||
sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(
|
sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(
|
||||||
bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0);
|
bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0);
|
||||||
@@ -773,40 +648,38 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|||||||
if constexpr (!use_mfma) {
|
if constexpr (!use_mfma) {
|
||||||
for (int n = 0; n < N; n++) {
|
for (int n = 0; n < N; n++) {
|
||||||
for (int y = 0; y < YTILE; y++) {
|
for (int y = 0; y < YTILE; y++) {
|
||||||
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
|
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x118, 0xf, 0xf,
|
||||||
: "=v"(sum[n][y])
|
1); // row_shr8
|
||||||
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
|
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x114, 0xf, 0xf,
|
||||||
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
|
1); // row_shr4
|
||||||
: "=v"(sum[n][y])
|
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x112, 0xf, 0xf,
|
||||||
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
|
1); // row_shr2
|
||||||
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
|
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x111, 0xf, 0xf,
|
||||||
: "=v"(sum[n][y])
|
1); // row_shr1
|
||||||
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
|
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x142, 0xf, 0xf,
|
||||||
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
|
1); // ROW_BCAST15
|
||||||
: "=v"(sum[n][y])
|
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x143, 0xf, 0xf,
|
||||||
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
|
1); // ROW_BCAST31
|
||||||
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
|
|
||||||
: "=v"(sum[n][y])
|
|
||||||
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
|
|
||||||
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
|
|
||||||
: "=v"(sum[n][y])
|
|
||||||
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (threadIdx.x == 63) {
|
if (threadIdx.x == 63) {
|
||||||
|
scalar_t biases[N][YTILE] = {};
|
||||||
|
if (BIAS)
|
||||||
|
for (int n = 0; n < N; n++) {
|
||||||
|
for (int y = 0; y < YTILE; y++) {
|
||||||
|
biases[n][y] = BIAS[(m + y) % Bx + (n % By) * Bx];
|
||||||
|
}
|
||||||
|
}
|
||||||
for (int n = 0; n < N; n++) {
|
for (int n = 0; n < N; n++) {
|
||||||
for (int i = 0; i < YTILE; i++) {
|
for (int y = 0; y < YTILE; y++) {
|
||||||
if (commitColumn[i]) {
|
if (commitColumn[y]) {
|
||||||
if constexpr (std::is_same_v<scalar_t, half>) {
|
if constexpr (std::is_same_v<scalar_t, half>) {
|
||||||
if (BIAS)
|
sum[n][y] += __half2float(biases[n][y]);
|
||||||
sum[n][i] += __half2float(BIAS[(m + i) % Bx + (n % By) * M]);
|
|
||||||
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
|
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
|
||||||
if (BIAS)
|
sum[n][y] += __bfloat162float(biases[n][y]);
|
||||||
sum[n][i] +=
|
|
||||||
__bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
|
|
||||||
}
|
}
|
||||||
C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
|
C[m + y + n * M] = __float2s<scalar_t>(sum[n][y]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -819,44 +692,39 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|||||||
// float accm1 = 0;
|
// float accm1 = 0;
|
||||||
// for (int i=0; i<64; i++)
|
// for (int i=0; i<64; i++)
|
||||||
// accm1 += __shfl(sum4[n][y][i%4], i);
|
// accm1 += __shfl(sum4[n][y][i%4], i);
|
||||||
|
|
||||||
float accm = sum4[n][y][0];
|
float accm = sum4[n][y][0];
|
||||||
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
|
accm += __builtin_amdgcn_mov_dpp(sum4[n][y][1], 0x101, 0xf, 0xf,
|
||||||
: "=v"(accm)
|
1); // row_shl1
|
||||||
: "0"(accm), "v"(sum4[n][y][1]), "v"(accm));
|
accm += __builtin_amdgcn_mov_dpp(sum4[n][y][2], 0x102, 0xf, 0xf,
|
||||||
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
|
1); // row_shl2
|
||||||
: "=v"(accm)
|
accm += __builtin_amdgcn_mov_dpp(sum4[n][y][3], 0x103, 0xf, 0xf,
|
||||||
: "0"(accm), "v"(sum4[n][y][2]), "v"(accm));
|
1); // row_shl3
|
||||||
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
|
accm += __builtin_amdgcn_mov_dpp(accm, 0x104, 0xf, 0xf,
|
||||||
: "=v"(accm)
|
1); // row_shl4
|
||||||
: "0"(accm), "v"(sum4[n][y][3]), "v"(accm));
|
accm += __builtin_amdgcn_mov_dpp(accm, 0x108, 0xf, 0xf,
|
||||||
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 "
|
1); // row_shl8
|
||||||
: "=v"(accm)
|
accm = __builtin_amdgcn_mov_dpp(accm, 0x11f, 0xf, 0xf,
|
||||||
: "0"(accm), "v"(accm), "v"(accm));
|
1); // row_shr15
|
||||||
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
|
accm += __builtin_amdgcn_mov_dpp(accm, 0x142, 0xf, 0xf,
|
||||||
: "=v"(accm)
|
1); // ROW_BCAST15
|
||||||
: "0"(accm), "v"(accm), "v"(accm));
|
accm += __builtin_amdgcn_mov_dpp(accm, 0x143, 0xf, 0xf,
|
||||||
asm("s_nop 0\n\tv_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 "
|
1); // ROW_BCAST31
|
||||||
: "=v"(accm)
|
|
||||||
: "0"(accm), "v"(accm));
|
|
||||||
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
|
|
||||||
: "=v"(accm)
|
|
||||||
: "0"(accm), "v"(accm), "v"(accm));
|
|
||||||
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
|
|
||||||
: "=v"(accm)
|
|
||||||
: "0"(accm), "v"(accm), "v"(accm));
|
|
||||||
|
|
||||||
sum4[n][y][0] = accm;
|
sum4[n][y][0] = accm;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (threadIdx.x == 63) {
|
if (threadIdx.x == 63) {
|
||||||
|
scalar_t biases[N][YTILE] = {};
|
||||||
|
if (BIAS)
|
||||||
|
for (int n = 0; n < N; n++) {
|
||||||
|
for (int y = 0; y < YTILE; y++) {
|
||||||
|
biases[n][y] = BIAS[(m + y) % Bx + (n % By) * Bx];
|
||||||
|
}
|
||||||
|
}
|
||||||
for (int n = 0; n < N; n++) {
|
for (int n = 0; n < N; n++) {
|
||||||
for (int i = 0; i < YTILE; i++) {
|
for (int y = 0; y < YTILE; y++) {
|
||||||
if (commitColumn[i]) {
|
if (commitColumn[y]) {
|
||||||
if (BIAS)
|
sum4[n][y][0] += __bfloat162float(biases[n][y]);
|
||||||
sum4[n][i][0] +=
|
C[m + y + n * M] = __float2bfloat16(sum4[n][y][0]);
|
||||||
__bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
|
|
||||||
C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -880,9 +748,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|||||||
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
|
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
|
||||||
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
|
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
|
||||||
int UNRL, int N>
|
int UNRL, int N>
|
||||||
__global__ void wvSplitK_hf_(const int K, const int M, const int Bx,
|
__global__ void wvSplitK_hf_(const int K, const int Kbp, const int Kap,
|
||||||
const int By, const scalar_t* B,
|
const int M, const int Bx, const int By,
|
||||||
const scalar_t* __restrict__ A,
|
const scalar_t* B, const scalar_t* __restrict__ A,
|
||||||
const scalar_t* __restrict__ BIAS, scalar_t* C,
|
const scalar_t* __restrict__ BIAS, scalar_t* C,
|
||||||
const int _WvPrGrp, const int CuCount) {
|
const int _WvPrGrp, const int CuCount) {
|
||||||
UNREACHABLE_CODE
|
UNREACHABLE_CODE
|
||||||
@@ -894,8 +762,9 @@ __global__ void wvSplitK_hf_(const int K, const int M, const int Bx,
|
|||||||
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
|
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
|
||||||
int UNRL, int N>
|
int UNRL, int N>
|
||||||
__global__ void __launch_bounds__(WvPrGrp* THRDS)
|
__global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||||
wvSplitK_hf_big_(const int K, const int M, const int Bx, const int By,
|
wvSplitK_hf_big_(const int K, const int Kbp, const int Kap, const int M,
|
||||||
const scalar_t* B, const scalar_t* __restrict__ A,
|
const int Bx, const int By, const scalar_t* B,
|
||||||
|
const scalar_t* __restrict__ A,
|
||||||
const scalar_t* __restrict__ BIAS, scalar_t* C,
|
const scalar_t* __restrict__ BIAS, scalar_t* C,
|
||||||
const int _WvPrGrp, const int CuCount) {
|
const int _WvPrGrp, const int CuCount) {
|
||||||
constexpr int max_lds_len = LDS_SIZE / 2;
|
constexpr int max_lds_len = LDS_SIZE / 2;
|
||||||
@@ -966,13 +835,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|||||||
//----------------------------------------------------
|
//----------------------------------------------------
|
||||||
#define PCML
|
#define PCML
|
||||||
#ifndef PCML
|
#ifndef PCML
|
||||||
for (uint32_t k = 0; k < min__(K * N, max_lds_len);
|
for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK;
|
||||||
k += THRDS * WvPrGrp * A_CHUNK) {
|
k < min__(Kap * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
|
||||||
uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);
|
#if defined(__gfx950__)
|
||||||
|
__builtin_amdgcn_global_load_lds((int*)(&A[k]), (int*)(&s[k]), 16, 0, 0);
|
||||||
if (k_in >= min__(K * N, max_lds_len)) break;
|
#else
|
||||||
|
*((bigType*)(&s[k])) = *((bigType*)(&A[k]));
|
||||||
*((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in]));
|
#endif
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
#endif
|
#endif
|
||||||
@@ -987,10 +856,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|||||||
? kFit
|
? kFit
|
||||||
: (kFit - kFit % TUC); // round up to multiple of TUC
|
: (kFit - kFit % TUC); // round up to multiple of TUC
|
||||||
// if (kFit == 0) kFit = TUC;
|
// if (kFit == 0) kFit = TUC;
|
||||||
kFit = min__(kFit, K);
|
kFit = min__(kFit, Kap);
|
||||||
|
|
||||||
float sum[N][YTILE];
|
|
||||||
scalar8 sum4[N][YTILE];
|
|
||||||
|
|
||||||
//----------------------------------------------------
|
//----------------------------------------------------
|
||||||
// Each wave works on a single column of weight matrix.
|
// Each wave works on a single column of weight matrix.
|
||||||
@@ -1021,15 +887,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|||||||
// YTILE represents how many column of weight matrix
|
// YTILE represents how many column of weight matrix
|
||||||
// are being worked on by each wave.
|
// are being worked on by each wave.
|
||||||
//----------------------------------------------------
|
//----------------------------------------------------
|
||||||
for (int i = 0; i < YTILE; i++)
|
float sum[N][YTILE] = {};
|
||||||
for (int n = 0; n < N; n++)
|
scalar8 sum4[N][YTILE] = {};
|
||||||
if constexpr (!use_mfma)
|
|
||||||
sum[n][i] = 0;
|
|
||||||
else
|
|
||||||
sum4[n][i] = {0, 0, 0, 0};
|
|
||||||
|
|
||||||
bigType bigA[N][UNRL];
|
|
||||||
bigType bigB[YTILE][UNRL];
|
|
||||||
//----------------------------------------------------
|
//----------------------------------------------------
|
||||||
// Fetch weight matrix B in interleaved K-split!
|
// Fetch weight matrix B in interleaved K-split!
|
||||||
// - Each thread (lane) is fetching 8 elements (A_Chunk)
|
// - Each thread (lane) is fetching 8 elements (A_Chunk)
|
||||||
@@ -1048,18 +908,26 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|||||||
// TODO: Logic below will only work when K is multiple of 8
|
// TODO: Logic below will only work when K is multiple of 8
|
||||||
//----------------------------------------------------
|
//----------------------------------------------------
|
||||||
for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
|
for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
|
||||||
|
bigType bigA[N][UNRL] = {};
|
||||||
|
bigType bigB[YTILE][UNRL];
|
||||||
|
|
||||||
#ifdef PCML
|
#ifdef PCML
|
||||||
if ((k1 == 0) || (k1 == kBase + kFit)) { // load next chunk of A[] to LDS
|
if ((k1 == 0) || (k1 == kBase + kFit)) { // load next chunk of A[] to LDS
|
||||||
if (k1 != 0) kBase += kFit;
|
if (k1 != 0) kBase += kFit;
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
for (uint32_t k = 0; k < kFit; k += THRDS * _WvPrGrp * A_CHUNK) {
|
for (uint32_t k = 0; k < kFit; k += THRDS * _WvPrGrp * A_CHUNK) {
|
||||||
uint32_t kOff = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);
|
uint32_t kOff = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);
|
||||||
if (kBase + kOff >= K) break;
|
if (kBase + kOff >= Kap) break;
|
||||||
if (kOff >= kFit) break;
|
if (kOff >= kFit) break;
|
||||||
for (uint32_t n = 0; n < N; n++) {
|
for (uint32_t n = 0; n < N; n++) {
|
||||||
uint32_t k_in = kBase + n * K + kOff;
|
uint32_t k_in = kBase + n * Kap + kOff;
|
||||||
uint32_t k_ot = n * kFit + kOff;
|
uint32_t k_ot = n * kFit + kOff;
|
||||||
|
#if defined(__gfx950__)
|
||||||
|
__builtin_amdgcn_global_load_lds((int*)(&A[k_in]), (int*)(&s[k_ot]),
|
||||||
|
16, 0, 0);
|
||||||
|
#else
|
||||||
*((bigType*)(&s[k_ot])) = *((bigType*)(&A[k_in]));
|
*((bigType*)(&s[k_ot])) = *((bigType*)(&A[k_in]));
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
@@ -1072,11 +940,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|||||||
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
|
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
|
||||||
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
|
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
|
||||||
uint32_t k_ = k + threadIdx.x * A_CHUNK;
|
uint32_t k_ = k + threadIdx.x * A_CHUNK;
|
||||||
if (k_ >= K) break;
|
const scalar_t* B_ = &B[min__(k_, K - A_CHUNK)];
|
||||||
|
for (int y = 0; y < YTILE; y++)
|
||||||
const scalar_t* B_ = &B[(m + 0) * K + k_];
|
bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[min__(y + m, M - 1) * Kbp])));
|
||||||
for (int b = 0; b < YTILE; b++)
|
|
||||||
bigB[b][k2].h8 = (loadnt((scalar8*)(&B_[b * K])));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fetch activation matrix from either just LDS or from both LDS / memory
|
// Fetch activation matrix from either just LDS or from both LDS / memory
|
||||||
@@ -1085,17 +951,14 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|||||||
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
|
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
|
||||||
uint32_t k_ = k + threadIdx.x * A_CHUNK;
|
uint32_t k_ = k + threadIdx.x * A_CHUNK;
|
||||||
if (k_ >= K) break;
|
if (k_ >= K) break;
|
||||||
|
|
||||||
// Fetch A activation matrix in interleaved fashion from LDS or memory
|
|
||||||
|
|
||||||
for (int n = 0; n < N; n++) {
|
for (int n = 0; n < N; n++) {
|
||||||
#ifdef PCML
|
#ifdef PCML
|
||||||
bigA[n][k2] = *((const bigType*)(&(s[k_ - kBase + kFit * n])));
|
bigA[n][k2] = *((const bigType*)(&(s[k_ - kBase + kFit * n])));
|
||||||
#else
|
#else
|
||||||
if (k_ + K * n < 32 * 1024)
|
if (k_ + Kap * n < max_lds_len)
|
||||||
bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n])));
|
bigA[n][k2] = *((const bigType*)(&(s[k_ + Kap * n])));
|
||||||
else
|
else
|
||||||
bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n])));
|
bigA[n][k2] = *((const bigType*)(&(A[k_ + Kap * n])));
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1103,22 +966,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|||||||
// Do the matrix multiplication in interleaved manner
|
// Do the matrix multiplication in interleaved manner
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
|
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
|
||||||
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
|
|
||||||
uint32_t k_ = k + threadIdx.x * A_CHUNK;
|
|
||||||
if (k_ >= K) break;
|
|
||||||
#pragma unroll
|
|
||||||
for (uint32_t n = 0; n < N; n++) {
|
for (uint32_t n = 0; n < N; n++) {
|
||||||
// Do the matrix multiplication of activation and weight matrix
|
|
||||||
// - Remember the accumulation is happening for K-split of 64!
|
|
||||||
#pragma unroll
|
|
||||||
for (int y = 0; y < YTILE; y++) {
|
for (int y = 0; y < YTILE; y++) {
|
||||||
if constexpr (!use_mfma)
|
if constexpr (!use_mfma)
|
||||||
#pragma unroll
|
|
||||||
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
|
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
|
||||||
DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b])
|
DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b])
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
#pragma unroll
|
|
||||||
for (uint32_t b = 0; b < A_CHUNK / 4; b++)
|
for (uint32_t b = 0; b < A_CHUNK / 4; b++)
|
||||||
sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(
|
sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(
|
||||||
bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0);
|
bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0);
|
||||||
@@ -1141,40 +995,38 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|||||||
if constexpr (!use_mfma) {
|
if constexpr (!use_mfma) {
|
||||||
for (int n = 0; n < N; n++) {
|
for (int n = 0; n < N; n++) {
|
||||||
for (int y = 0; y < YTILE; y++) {
|
for (int y = 0; y < YTILE; y++) {
|
||||||
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
|
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x118, 0xf, 0xf,
|
||||||
: "=v"(sum[n][y])
|
1); // row_shr8
|
||||||
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
|
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x114, 0xf, 0xf,
|
||||||
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
|
1); // row_shr4
|
||||||
: "=v"(sum[n][y])
|
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x112, 0xf, 0xf,
|
||||||
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
|
1); // row_shr2
|
||||||
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
|
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x111, 0xf, 0xf,
|
||||||
: "=v"(sum[n][y])
|
1); // row_shr1
|
||||||
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
|
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x142, 0xf, 0xf,
|
||||||
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
|
1); // ROW_BCAST15
|
||||||
: "=v"(sum[n][y])
|
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x143, 0xf, 0xf,
|
||||||
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
|
1); // ROW_BCAST31
|
||||||
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
|
|
||||||
: "=v"(sum[n][y])
|
|
||||||
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
|
|
||||||
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
|
|
||||||
: "=v"(sum[n][y])
|
|
||||||
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (threadIdx.x == 63) {
|
if (threadIdx.x == 63) {
|
||||||
|
scalar_t biases[N][YTILE] = {};
|
||||||
|
if (BIAS)
|
||||||
|
for (int n = 0; n < N; n++) {
|
||||||
|
for (int y = 0; y < YTILE; y++) {
|
||||||
|
biases[n][y] = BIAS[(m + y) % Bx + (n % By) * Bx];
|
||||||
|
}
|
||||||
|
}
|
||||||
for (int n = 0; n < N; n++) {
|
for (int n = 0; n < N; n++) {
|
||||||
for (int i = 0; i < YTILE; i++) {
|
for (int y = 0; y < YTILE; y++) {
|
||||||
if (commitColumn[i]) {
|
if (commitColumn[y]) {
|
||||||
if constexpr (std::is_same_v<scalar_t, half>) {
|
if constexpr (std::is_same_v<scalar_t, half>) {
|
||||||
if (BIAS)
|
sum[n][y] += __half2float(biases[n][y]);
|
||||||
sum[n][i] += __half2float(BIAS[(m + i) % Bx + (n % By) * M]);
|
|
||||||
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
|
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
|
||||||
if (BIAS)
|
sum[n][y] += __bfloat162float(biases[n][y]);
|
||||||
sum[n][i] +=
|
|
||||||
__bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
|
|
||||||
}
|
}
|
||||||
C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
|
C[m + y + n * M] = __float2s<scalar_t>(sum[n][y]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1185,42 +1037,38 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int y = 0; y < YTILE; y++) {
|
for (int y = 0; y < YTILE; y++) {
|
||||||
float accm = sum4[n][y][0];
|
float accm = sum4[n][y][0];
|
||||||
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
|
accm += __builtin_amdgcn_mov_dpp(sum4[n][y][1], 0x101, 0xf, 0xf,
|
||||||
: "=v"(accm)
|
1); // row_shl1
|
||||||
: "0"(accm), "v"(sum4[n][y][1]), "v"(accm));
|
accm += __builtin_amdgcn_mov_dpp(sum4[n][y][2], 0x102, 0xf, 0xf,
|
||||||
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
|
1); // row_shl2
|
||||||
: "=v"(accm)
|
accm += __builtin_amdgcn_mov_dpp(sum4[n][y][3], 0x103, 0xf, 0xf,
|
||||||
: "0"(accm), "v"(sum4[n][y][2]), "v"(accm));
|
1); // row_shl3
|
||||||
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
|
accm += __builtin_amdgcn_mov_dpp(accm, 0x104, 0xf, 0xf,
|
||||||
: "=v"(accm)
|
1); // row_shl4
|
||||||
: "0"(accm), "v"(sum4[n][y][3]), "v"(accm));
|
accm += __builtin_amdgcn_mov_dpp(accm, 0x108, 0xf, 0xf,
|
||||||
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 "
|
1); // row_shl8
|
||||||
: "=v"(accm)
|
accm = __builtin_amdgcn_mov_dpp(accm, 0x11f, 0xf, 0xf,
|
||||||
: "0"(accm), "v"(accm), "v"(accm));
|
1); // row_shr15
|
||||||
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
|
accm += __builtin_amdgcn_mov_dpp(accm, 0x142, 0xf, 0xf,
|
||||||
: "=v"(accm)
|
1); // ROW_BCAST15
|
||||||
: "0"(accm), "v"(accm), "v"(accm));
|
accm += __builtin_amdgcn_mov_dpp(accm, 0x143, 0xf, 0xf,
|
||||||
asm("s_nop 0\n\tv_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 "
|
1); // ROW_BCAST31
|
||||||
: "=v"(accm)
|
|
||||||
: "0"(accm), "v"(accm));
|
|
||||||
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
|
|
||||||
: "=v"(accm)
|
|
||||||
: "0"(accm), "v"(accm), "v"(accm));
|
|
||||||
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
|
|
||||||
: "=v"(accm)
|
|
||||||
: "0"(accm), "v"(accm), "v"(accm));
|
|
||||||
|
|
||||||
sum4[n][y][0] = accm;
|
sum4[n][y][0] = accm;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (threadIdx.x == 63) {
|
if (threadIdx.x == 63) {
|
||||||
|
scalar_t biases[N][YTILE] = {};
|
||||||
|
if (BIAS)
|
||||||
|
for (int n = 0; n < N; n++) {
|
||||||
|
for (int y = 0; y < YTILE; y++) {
|
||||||
|
biases[n][y] = BIAS[(m + y) % Bx + (n % By) * Bx];
|
||||||
|
}
|
||||||
|
}
|
||||||
for (int n = 0; n < N; n++) {
|
for (int n = 0; n < N; n++) {
|
||||||
for (int i = 0; i < YTILE; i++) {
|
for (int y = 0; y < YTILE; y++) {
|
||||||
if (commitColumn[i]) {
|
if (commitColumn[y]) {
|
||||||
if (BIAS)
|
sum4[n][y][0] += __bfloat162float(biases[n][y]);
|
||||||
sum4[n][i][0] +=
|
C[m + y + n * M] = __float2bfloat16(sum4[n][y][0]);
|
||||||
__bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
|
|
||||||
C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1244,8 +1092,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|||||||
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
|
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
|
||||||
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
|
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
|
||||||
int UNRL, int N>
|
int UNRL, int N>
|
||||||
__global__ void wvSplitK_hf_big_(const int K, const int M, const int Bx,
|
__global__ void wvSplitK_hf_big_(const int K, const int Kbp, const int Kap,
|
||||||
const int By, const scalar_t* B,
|
const int M, const int Bx, const int By,
|
||||||
|
const scalar_t* B,
|
||||||
const scalar_t* __restrict__ A,
|
const scalar_t* __restrict__ A,
|
||||||
const scalar_t* __restrict__ BIAS, scalar_t* C,
|
const scalar_t* __restrict__ BIAS, scalar_t* C,
|
||||||
const int _WvPrGrp, const int CuCount) {
|
const int _WvPrGrp, const int CuCount) {
|
||||||
@@ -1272,6 +1121,8 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
|
|||||||
auto M_in = in_a.size(0);
|
auto M_in = in_a.size(0);
|
||||||
auto K_in = in_a.size(1);
|
auto K_in = in_a.size(1);
|
||||||
auto N_in = in_b.size(0);
|
auto N_in = in_b.size(0);
|
||||||
|
auto Kap_in = in_a.stride(0);
|
||||||
|
auto Kbp_in = in_b.stride(0);
|
||||||
auto Bx_in =
|
auto Bx_in =
|
||||||
(in_bias.has_value() && in_bias->numel() > 0)
|
(in_bias.has_value() && in_bias->numel() > 0)
|
||||||
? (in_bias->sizes().size() == 2) ? in_bias->size(1) : in_bias->size(0)
|
? (in_bias->sizes().size() == 2) ? in_bias->size(1) : in_bias->size(0)
|
||||||
@@ -1296,27 +1147,30 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
|
|||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
const int max_lds_len = get_lds_size() / 2;
|
const int max_lds_len = get_lds_size() / 2;
|
||||||
|
|
||||||
#define WVSPLITK(_YTILE, _UNRL, _N) \
|
#define WVSPLITK(_YTILE, _UNRL, _N) \
|
||||||
{ \
|
{ \
|
||||||
dim3 block(64, 16); \
|
dim3 block(64, 16); \
|
||||||
int __wvPrGrp = mindiv(M_in, CuCount * _YTILE, 16); \
|
int __wvPrGrp = mindiv(M_in, CuCount * _YTILE, 16); \
|
||||||
if ((K_in * N_in <= max_lds_len) && (M_in % _YTILE == 0)) \
|
if ((Kbp_in * N_in <= max_lds_len) && (M_in % _YTILE == 0)) \
|
||||||
wvSplitK_hf_sml_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
|
wvSplitK_hf_sml_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
|
||||||
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
|
<<<grid, block, 0, stream>>>(K_in, Kap_in, Kbp_in, M_in, Bx_in, \
|
||||||
biasf4, c, __wvPrGrp, CuCount); \
|
By_in, af4, bf4, biasf4, c, __wvPrGrp, \
|
||||||
else if (K_in * N_in <= max_lds_len * 1.2) \
|
CuCount); \
|
||||||
wvSplitK_hf_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
|
else if (Kbp_in * N_in <= max_lds_len * 1.2) \
|
||||||
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
|
wvSplitK_hf_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
|
||||||
biasf4, c, __wvPrGrp, CuCount); \
|
<<<grid, block, 0, stream>>>(K_in, Kap_in, Kbp_in, M_in, Bx_in, \
|
||||||
else \
|
By_in, af4, bf4, biasf4, c, __wvPrGrp, \
|
||||||
wvSplitK_hf_big_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
|
CuCount); \
|
||||||
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
|
else \
|
||||||
biasf4, c, __wvPrGrp, CuCount); \
|
wvSplitK_hf_big_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
|
||||||
|
<<<grid, block, 0, stream>>>(K_in, Kap_in, Kbp_in, M_in, Bx_in, \
|
||||||
|
By_in, af4, bf4, biasf4, c, __wvPrGrp, \
|
||||||
|
CuCount); \
|
||||||
}
|
}
|
||||||
|
|
||||||
#define WVSPLIT_TILE(_sYT, __N) \
|
#define WVSPLIT_TILE(_sYT, __N) \
|
||||||
{ \
|
{ \
|
||||||
bool fit_lds = (K_in * N_in <= max_lds_len); \
|
bool fit_lds = (Kbp_in * N_in <= max_lds_len); \
|
||||||
if (_sYT <= 1) \
|
if (_sYT <= 1) \
|
||||||
WVSPLITK(1, 4, __N) \
|
WVSPLITK(1, 4, __N) \
|
||||||
else if ((__N == 1) || (!fit_lds) || (_sYT <= 4 * 2)) \
|
else if ((__N == 1) || (!fit_lds) || (_sYT <= 4 * 2)) \
|
||||||
|
|||||||
@@ -426,6 +426,22 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
" Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()");
|
" Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()");
|
||||||
// conditionally compiled so impl registration is in source file
|
// conditionally compiled so impl registration is in source file
|
||||||
|
|
||||||
|
// Expert-specialization mxfp8 blockscaled grouped quantization (SM100+).
|
||||||
|
ops.def(
|
||||||
|
"mxfp8_experts_quant("
|
||||||
|
" Tensor input, Tensor problem_sizes, Tensor expert_offsets,"
|
||||||
|
" Tensor blockscale_offsets, Tensor! quant_output, Tensor! scale_factor)"
|
||||||
|
" -> ()");
|
||||||
|
// conditionally compiled so impl registration is in source file
|
||||||
|
|
||||||
|
// Expert-specialization mxfp8 blockscaled grouped GEMM (SM100+).
|
||||||
|
ops.def(
|
||||||
|
"cutlass_mxfp8_grouped_mm("
|
||||||
|
" Tensor a, Tensor b, Tensor sfa, Tensor sfb, Tensor! out,"
|
||||||
|
" Tensor problem_sizes, Tensor expert_offsets, Tensor blockscale_offsets)"
|
||||||
|
" -> ()");
|
||||||
|
// conditionally compiled so impl registration is in source file
|
||||||
|
|
||||||
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
|
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
|
||||||
// quantization, as well as bias
|
// quantization, as well as bias
|
||||||
ops.def(
|
ops.def(
|
||||||
@@ -489,19 +505,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
&get_cutlass_moe_mm_problem_sizes_from_expert_offsets);
|
&get_cutlass_moe_mm_problem_sizes_from_expert_offsets);
|
||||||
|
|
||||||
// A function that computes data required to run fused MoE with w8a8 grouped
|
// A function that computes data required to run fused MoE with w8a8 grouped
|
||||||
// GEMM and PPLX. It takes expert_num_tokens and non_zero_expert_idxs
|
// GEMM in batched expert format. It takes expert_num_tokens
|
||||||
// as an input, and computes expert_offsets (token start indices of each
|
// as an input, and computes expert_offsets (token start indices of each
|
||||||
// expert). In addition to this, it computes problem sizes for each expert's
|
// expert). In addition to this, it computes problem sizes for each expert's
|
||||||
// multiplication used by the two mms called from fused MoE operation.
|
// multiplication used by the two mms called from fused MoE operation.
|
||||||
ops.def(
|
ops.def(
|
||||||
"get_cutlass_pplx_moe_mm_data(Tensor! expert_offsets, "
|
"get_cutlass_batched_moe_mm_data(Tensor! expert_offsets, "
|
||||||
" Tensor! problem_sizes1, "
|
" Tensor! problem_sizes1, "
|
||||||
" Tensor! problem_sizes2, "
|
" Tensor! problem_sizes2, "
|
||||||
" Tensor expert_num_tokens, "
|
" Tensor expert_num_tokens, "
|
||||||
" int num_local_experts, int padded_m, "
|
" int num_local_experts, int padded_m, "
|
||||||
" int n, int k) -> ()");
|
" int n, int k) -> ()");
|
||||||
ops.impl("get_cutlass_pplx_moe_mm_data", torch::kCUDA,
|
ops.impl("get_cutlass_batched_moe_mm_data", torch::kCUDA,
|
||||||
&get_cutlass_pplx_moe_mm_data);
|
&get_cutlass_batched_moe_mm_data);
|
||||||
|
|
||||||
// Check if cutlass scaled_mm supports block quantization (used by DeepSeekV3)
|
// Check if cutlass scaled_mm supports block quantization (used by DeepSeekV3)
|
||||||
ops.def(
|
ops.def(
|
||||||
@@ -640,7 +656,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
"int block_size,"
|
"int block_size,"
|
||||||
"Tensor? block_idx_first_scheduled_token,"
|
"Tensor? block_idx_first_scheduled_token,"
|
||||||
"Tensor? block_idx_last_scheduled_token,"
|
"Tensor? block_idx_last_scheduled_token,"
|
||||||
"Tensor? initial_state_idx) -> ()");
|
"Tensor? initial_state_idx,"
|
||||||
|
"Tensor? cu_chunk_seqlen,"
|
||||||
|
"Tensor? last_chunk_indices) -> ()");
|
||||||
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
|
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
|
||||||
|
|
||||||
// Hadamard transforms
|
// Hadamard transforms
|
||||||
|
|||||||
@@ -262,7 +262,9 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
|||||||
|
|
||||||
# Build the vLLM wheel
|
# Build the vLLM wheel
|
||||||
# if USE_SCCACHE is set, use sccache to speed up compilation
|
# if USE_SCCACHE is set, use sccache to speed up compilation
|
||||||
|
# AWS credentials mounted at ~/.aws/credentials for sccache S3 auth (optional)
|
||||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||||
|
--mount=type=secret,id=aws-credentials,target=/root/.aws/credentials,required=false \
|
||||||
if [ "$USE_SCCACHE" = "1" ]; then \
|
if [ "$USE_SCCACHE" = "1" ]; then \
|
||||||
echo "Installing sccache..." \
|
echo "Installing sccache..." \
|
||||||
&& case "${TARGETPLATFORM}" in \
|
&& case "${TARGETPLATFORM}" in \
|
||||||
@@ -308,7 +310,7 @@ RUN --mount=type=cache,target=/root/.cache/ccache \
|
|||||||
#################### CSRC BUILD IMAGE ####################
|
#################### CSRC BUILD IMAGE ####################
|
||||||
|
|
||||||
#################### EXTENSIONS BUILD IMAGE ####################
|
#################### EXTENSIONS BUILD IMAGE ####################
|
||||||
# Build DeepGEMM, pplx-kernels, DeepEP - runs in PARALLEL with csrc-build
|
# Build DeepGEMM, DeepEP - runs in PARALLEL with csrc-build
|
||||||
# This stage is independent and doesn't affect csrc cache
|
# This stage is independent and doesn't affect csrc cache
|
||||||
FROM base AS extensions-build
|
FROM base AS extensions-build
|
||||||
ARG CUDA_VERSION
|
ARG CUDA_VERSION
|
||||||
@@ -335,10 +337,9 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
|||||||
# Ensure the wheel dir exists so COPY won't fail when DeepGEMM is skipped
|
# Ensure the wheel dir exists so COPY won't fail when DeepGEMM is skipped
|
||||||
RUN mkdir -p /tmp/deepgemm/dist && touch /tmp/deepgemm/dist/.deepgemm_skipped
|
RUN mkdir -p /tmp/deepgemm/dist && touch /tmp/deepgemm/dist/.deepgemm_skipped
|
||||||
|
|
||||||
# Build pplx-kernels and DeepEP wheels
|
# Build DeepEP wheels
|
||||||
COPY tools/ep_kernels/install_python_libraries.sh /tmp/install_python_libraries.sh
|
COPY tools/ep_kernels/install_python_libraries.sh /tmp/install_python_libraries.sh
|
||||||
# Defaults moved here from tools/ep_kernels/install_python_libraries.sh for centralized version management
|
# Defaults moved here from tools/ep_kernels/install_python_libraries.sh for centralized version management
|
||||||
ARG PPLX_COMMIT_HASH=12cecfd
|
|
||||||
ARG DEEPEP_COMMIT_HASH=73b6ea4
|
ARG DEEPEP_COMMIT_HASH=73b6ea4
|
||||||
ARG NVSHMEM_VER
|
ARG NVSHMEM_VER
|
||||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||||
@@ -347,7 +348,6 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
|||||||
/tmp/install_python_libraries.sh \
|
/tmp/install_python_libraries.sh \
|
||||||
--workspace /tmp/ep_kernels_workspace \
|
--workspace /tmp/ep_kernels_workspace \
|
||||||
--mode wheel \
|
--mode wheel \
|
||||||
${PPLX_COMMIT_HASH:+--pplx-ref "$PPLX_COMMIT_HASH"} \
|
|
||||||
${DEEPEP_COMMIT_HASH:+--deepep-ref "$DEEPEP_COMMIT_HASH"} \
|
${DEEPEP_COMMIT_HASH:+--deepep-ref "$DEEPEP_COMMIT_HASH"} \
|
||||||
${NVSHMEM_VER:+--nvshmem-ver "$NVSHMEM_VER"} && \
|
${NVSHMEM_VER:+--nvshmem-ver "$NVSHMEM_VER"} && \
|
||||||
find /tmp/ep_kernels_workspace/nvshmem -name '*.a' -delete
|
find /tmp/ep_kernels_workspace/nvshmem -name '*.a' -delete
|
||||||
@@ -676,7 +676,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
|||||||
# Pytorch now installs NVSHMEM, setting LD_LIBRARY_PATH
|
# Pytorch now installs NVSHMEM, setting LD_LIBRARY_PATH
|
||||||
ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
|
ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
|
||||||
|
|
||||||
# Install EP kernels wheels (pplx-kernels and DeepEP) that have been built in the `build` stage
|
# Install EP kernels wheels (DeepEP) that have been built in the `build` stage
|
||||||
RUN --mount=type=bind,from=build,src=/tmp/ep_kernels_workspace/dist,target=/vllm-workspace/ep_kernels/dist \
|
RUN --mount=type=bind,from=build,src=/tmp/ep_kernels_workspace/dist,target=/vllm-workspace/ep_kernels/dist \
|
||||||
--mount=type=cache,target=/root/.cache/uv \
|
--mount=type=cache,target=/root/.cache/uv \
|
||||||
uv pip install --system ep_kernels/dist/*.whl --verbose \
|
uv pip install --system ep_kernels/dist/*.whl --verbose \
|
||||||
|
|||||||
@@ -6,8 +6,7 @@ ARG PYTHON_VERSION=3.12
|
|||||||
ARG PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/xpu"
|
ARG PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/xpu"
|
||||||
|
|
||||||
RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && \
|
RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && \
|
||||||
echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list && \
|
echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list
|
||||||
add-apt-repository -y ppa:kobuk-team/intel-graphics
|
|
||||||
|
|
||||||
RUN apt clean && apt-get update -y && \
|
RUN apt clean && apt-get update -y && \
|
||||||
apt-get install -y --no-install-recommends --fix-missing \
|
apt-get install -y --no-install-recommends --fix-missing \
|
||||||
@@ -28,9 +27,22 @@ RUN apt clean && apt-get update -y && \
|
|||||||
python3-pip
|
python3-pip
|
||||||
|
|
||||||
RUN apt update && apt upgrade -y && \
|
RUN apt update && apt upgrade -y && \
|
||||||
apt install -y libze1 libze-dev libze-intel-gpu1 intel-opencl-icd libze-intel-gpu-raytracing intel-ocloc && \
|
|
||||||
apt install -y intel-oneapi-compiler-dpcpp-cpp-2025.3
|
apt install -y intel-oneapi-compiler-dpcpp-cpp-2025.3
|
||||||
|
|
||||||
|
# Install UMD
|
||||||
|
RUN mkdir neo && \
|
||||||
|
cd neo && \
|
||||||
|
wget https://github.com/intel/intel-graphics-compiler/releases/download/v2.24.8/intel-igc-core-2_2.24.8+20344_amd64.deb && \
|
||||||
|
wget https://github.com/intel/intel-graphics-compiler/releases/download/v2.24.8/intel-igc-opencl-2_2.24.8+20344_amd64.deb && \
|
||||||
|
wget https://github.com/intel/compute-runtime/releases/download/25.48.36300.8/intel-ocloc_25.48.36300.8-0_amd64.deb && \
|
||||||
|
wget https://github.com/intel/compute-runtime/releases/download/25.48.36300.8/intel-opencl-icd_25.48.36300.8-0_amd64.deb && \
|
||||||
|
wget https://github.com/intel/compute-runtime/releases/download/25.48.36300.8/libigdgmm12_22.8.2_amd64.deb && \
|
||||||
|
wget https://github.com/intel/compute-runtime/releases/download/25.48.36300.8/libze-intel-gpu1_25.48.36300.8-0_amd64.deb && \
|
||||||
|
wget https://github.com/oneapi-src/level-zero/releases/download/v1.26.0/level-zero_1.26.0+u24.04_amd64.deb && \
|
||||||
|
dpkg -i *.deb && \
|
||||||
|
cd .. && \
|
||||||
|
rm -rf neo
|
||||||
|
|
||||||
ENV PATH="/root/.local/bin:$PATH"
|
ENV PATH="/root/.local/bin:$PATH"
|
||||||
ENV VIRTUAL_ENV="/opt/venv"
|
ENV VIRTUAL_ENV="/opt/venv"
|
||||||
ENV UV_PYTHON_INSTALL_DIR=/opt/uv/python
|
ENV UV_PYTHON_INSTALL_DIR=/opt/uv/python
|
||||||
@@ -103,9 +115,57 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
|||||||
# install development dependencies (for testing)
|
# install development dependencies (for testing)
|
||||||
RUN uv pip install -e tests/vllm_test_utils
|
RUN uv pip install -e tests/vllm_test_utils
|
||||||
|
|
||||||
# install nixl from source code
|
# install NIXL and UCX from source code
|
||||||
ENV NIXL_VERSION=0.7.0
|
ARG UCX_VERSION=e5d98879705239d254ede40b4a52891850cb5349
|
||||||
RUN python /workspace/vllm/tools/install_nixl_from_source_ubuntu.py
|
ARG NIXL_VERSION=0.7.0
|
||||||
|
|
||||||
|
RUN apt-get update && apt-get install -y \
|
||||||
|
pciutils \
|
||||||
|
net-tools \
|
||||||
|
iproute2 \
|
||||||
|
hwloc \
|
||||||
|
numactl \
|
||||||
|
wget \
|
||||||
|
curl \
|
||||||
|
git \
|
||||||
|
build-essential \
|
||||||
|
autoconf \
|
||||||
|
automake \
|
||||||
|
libtool \
|
||||||
|
pkg-config \
|
||||||
|
rdma-core \
|
||||||
|
libibverbs-dev \
|
||||||
|
ibverbs-utils \
|
||||||
|
libibverbs1 \
|
||||||
|
librdmacm-dev \
|
||||||
|
librdmacm1 \
|
||||||
|
libibumad-dev \
|
||||||
|
libibumad3 \
|
||||||
|
libibmad-dev \
|
||||||
|
libibmad5 \
|
||||||
|
infiniband-diags \
|
||||||
|
perftest \
|
||||||
|
ibutils \
|
||||||
|
libmlx5-1 \
|
||||||
|
libmlx4-1 \
|
||||||
|
ibverbs-providers \
|
||||||
|
librdmacm1t64
|
||||||
|
|
||||||
|
ENV PKG_CONFIG_PATH=/tmp/ucx_install/lib/pkgconfig:${PKG_CONFIG_PATH}
|
||||||
|
ENV LD_LIBRARY_PATH=/tmp/ucx_install/lib:${LD_LIBRARY_PATH}
|
||||||
|
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||||
|
git clone https://github.com/openucx/ucx /tmp/ucx_source && \
|
||||||
|
cd /tmp/ucx_source && git checkout "${UCX_VERSION}" && \
|
||||||
|
bash autogen.sh && \
|
||||||
|
./configure --prefix=/tmp/ucx_install --with-ze=yes --enable-examples --enable-mt && \
|
||||||
|
make CFLAGS="-Wno-error=incompatible-pointer-types" -j8 && make install && \
|
||||||
|
git clone https://github.com/ai-dynamo/nixl /tmp/nixl_source && \
|
||||||
|
cd /tmp/nixl_source && git checkout "${NIXL_VERSION}" && \
|
||||||
|
cd /tmp/nixl_source && \
|
||||||
|
uv pip install --upgrade meson pybind11 patchelf && \
|
||||||
|
uv pip install -r requirements.txt && \
|
||||||
|
uv pip install . && \
|
||||||
|
rm -rf /tmp/ucx_source /tmp/nixl_source
|
||||||
|
|
||||||
# FIX triton
|
# FIX triton
|
||||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||||
|
|||||||
@@ -52,9 +52,6 @@
|
|||||||
"DEEPGEMM_GIT_REF": {
|
"DEEPGEMM_GIT_REF": {
|
||||||
"default": "477618cd51baffca09c4b0b87e97c03fe827ef03"
|
"default": "477618cd51baffca09c4b0b87e97c03fe827ef03"
|
||||||
},
|
},
|
||||||
"PPLX_COMMIT_HASH": {
|
|
||||||
"default": "12cecfd"
|
|
||||||
},
|
|
||||||
"DEEPEP_COMMIT_HASH": {
|
"DEEPEP_COMMIT_HASH": {
|
||||||
"default": "73b6ea4"
|
"default": "73b6ea4"
|
||||||
},
|
},
|
||||||
|
|||||||
BIN
docs/assets/design/model_runner_v2/async_no_race_condition.png
Normal file
BIN
docs/assets/design/model_runner_v2/async_no_race_condition.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 130 KiB |
BIN
docs/assets/design/model_runner_v2/async_race_condition.png
Normal file
BIN
docs/assets/design/model_runner_v2/async_race_condition.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 128 KiB |
BIN
docs/assets/design/model_runner_v2/async_sched.png
Normal file
BIN
docs/assets/design/model_runner_v2/async_sched.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 254 KiB |
BIN
docs/assets/design/model_runner_v2/persistent_batch_mrv2.png
Normal file
BIN
docs/assets/design/model_runner_v2/persistent_batch_mrv2.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 73 KiB |
BIN
docs/assets/design/model_runner_v2/persistent_batch_v1.png
Normal file
BIN
docs/assets/design/model_runner_v2/persistent_batch_v1.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 65 KiB |
@@ -72,7 +72,7 @@ Follow these steps to run the script:
|
|||||||
]
|
]
|
||||||
```
|
```
|
||||||
|
|
||||||
5. Determine where you want to save the results, and pass that to `--output-dir`.
|
5. Set `--output-dir` and optionally `--experiment-name` to control where to save the results.
|
||||||
|
|
||||||
Example command:
|
Example command:
|
||||||
|
|
||||||
@@ -82,7 +82,8 @@ vllm bench sweep serve \
|
|||||||
--bench-cmd 'vllm bench serve --model meta-llama/Llama-2-7b-chat-hf --backend vllm --endpoint /v1/completions --dataset-name sharegpt --dataset-path benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json' \
|
--bench-cmd 'vllm bench serve --model meta-llama/Llama-2-7b-chat-hf --backend vllm --endpoint /v1/completions --dataset-name sharegpt --dataset-path benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json' \
|
||||||
--serve-params benchmarks/serve_hparams.json \
|
--serve-params benchmarks/serve_hparams.json \
|
||||||
--bench-params benchmarks/bench_hparams.json \
|
--bench-params benchmarks/bench_hparams.json \
|
||||||
-o benchmarks/results
|
--output-dir benchmarks/results \
|
||||||
|
--experiment-name demo
|
||||||
```
|
```
|
||||||
|
|
||||||
By default, each parameter combination is benchmarked 3 times to make the results more reliable. You can adjust the number of runs by setting `--num-runs`.
|
By default, each parameter combination is benchmarked 3 times to make the results more reliable. You can adjust the number of runs by setting `--num-runs`.
|
||||||
@@ -102,33 +103,41 @@ By default, each parameter combination is benchmarked 3 times to make the result
|
|||||||
!!! tip
|
!!! tip
|
||||||
You can use the `--resume` option to continue the parameter sweep if an unexpected error occurs, e.g., timeout when connecting to HF Hub.
|
You can use the `--resume` option to continue the parameter sweep if an unexpected error occurs, e.g., timeout when connecting to HF Hub.
|
||||||
|
|
||||||
### SLA Scanner
|
### Workload Explorer
|
||||||
|
|
||||||
`vllm bench sweep serve_sla` is a variant of `vllm bench sweep serve` that scans through values of request rate or concurrency (choose using `--sla-variable`) in order to find the tradeoff between latency and throughput. The results can then be [visualized](#visualization) to determine the feasible SLAs.
|
`vllm bench sweep serve_workload` is a variant of `vllm bench sweep serve` that explores different workload levels in order to find the tradeoff between latency and throughput. The results can also be [visualized](#visualization) to determine the feasible SLAs.
|
||||||
|
|
||||||
|
The workload can be expressed in terms of request rate or concurrency (choose using `--workload-var`).
|
||||||
|
|
||||||
Example command:
|
Example command:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
vllm bench sweep serve_sla \
|
vllm bench sweep serve_workload \
|
||||||
--serve-cmd 'vllm serve meta-llama/Llama-2-7b-chat-hf' \
|
--serve-cmd 'vllm serve meta-llama/Llama-2-7b-chat-hf' \
|
||||||
--bench-cmd 'vllm bench serve --model meta-llama/Llama-2-7b-chat-hf --backend vllm --endpoint /v1/completions --dataset-name sharegpt --dataset-path benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 100' \
|
--bench-cmd 'vllm bench serve --model meta-llama/Llama-2-7b-chat-hf --backend vllm --endpoint /v1/completions --dataset-name sharegpt --dataset-path benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 100' \
|
||||||
|
--workload-var max_concurrency \
|
||||||
--serve-params benchmarks/serve_hparams.json \
|
--serve-params benchmarks/serve_hparams.json \
|
||||||
--bench-params benchmarks/bench_hparams.json
|
--bench-params benchmarks/bench_hparams.json \
|
||||||
-o benchmarks/results
|
--num-runs 1 \
|
||||||
|
--output-dir benchmarks/results \
|
||||||
|
--experiment-name demo
|
||||||
```
|
```
|
||||||
|
|
||||||
The algorithm for scanning through different values of `sla_variable` can be summarized as follows:
|
The algorithm for exploring different workload levels can be summarized as follows:
|
||||||
|
|
||||||
1. Run the benchmark once with `sla_variable = 1` to simulate serial inference. This results in the lowest possible latency and throughput.
|
1. Run the benchmark by sending requests one at a time (serial inference, lowest workload). This results in the lowest possible latency and throughput.
|
||||||
2. Run the benchmark once with `sla_variable = num_prompts` to simulate batch inference over the whole dataset. This results in the highest possible latency and throughput.
|
2. Run the benchmark by sending all requests at once (batch inference, highest workload). This results in the highest possible latency and throughput.
|
||||||
3. Estimate the maximum value of `sla_variable` that can be supported by the server without oversaturating it.
|
3. Estimate the value of `workload_var` corresponding to Step 2.
|
||||||
4. Run the benchmark over intermediate values of `sla_variable` uniformly using the remaining iterations.
|
4. Run the benchmark over intermediate values of `workload_var` uniformly using the remaining iterations.
|
||||||
|
|
||||||
You can override the number of iterations in the algorithm by setting `--sla-iters`.
|
You can override the number of iterations in the algorithm by setting `--workload-iters`.
|
||||||
|
|
||||||
!!! tip
|
!!! tip
|
||||||
This is our equivalent of [GuideLLM's `--profile sweep`](https://github.com/vllm-project/guidellm/blob/v0.5.3/src/guidellm/benchmark/profiles.py#L575).
|
This is our equivalent of [GuideLLM's `--profile sweep`](https://github.com/vllm-project/guidellm/blob/v0.5.3/src/guidellm/benchmark/profiles.py#L575).
|
||||||
|
|
||||||
|
In general, `--workload-var max_concurrency` produces more reliable results because it directly controls the workload imposed on the vLLM engine.
|
||||||
|
Nevertheless, we default to `--workload-var request_rate` to maintain similar behavior as GuideLLM.
|
||||||
|
|
||||||
## Startup Benchmark
|
## Startup Benchmark
|
||||||
|
|
||||||
`vllm bench sweep startup` runs `vllm bench startup` across parameter combinations to compare cold/warm startup time for different engine settings.
|
`vllm bench sweep startup` runs `vllm bench startup` across parameter combinations to compare cold/warm startup time for different engine settings.
|
||||||
@@ -179,7 +188,8 @@ vllm bench sweep startup \
|
|||||||
--startup-cmd 'vllm bench startup --model Qwen/Qwen3-0.6B' \
|
--startup-cmd 'vllm bench startup --model Qwen/Qwen3-0.6B' \
|
||||||
--serve-params benchmarks/serve_hparams.json \
|
--serve-params benchmarks/serve_hparams.json \
|
||||||
--startup-params benchmarks/startup_hparams.json \
|
--startup-params benchmarks/startup_hparams.json \
|
||||||
-o benchmarks/results
|
--output-dir benchmarks/results \
|
||||||
|
--experiment-name demo
|
||||||
```
|
```
|
||||||
|
|
||||||
!!! important
|
!!! important
|
||||||
@@ -194,26 +204,34 @@ vllm bench sweep startup \
|
|||||||
|
|
||||||
Control the variables to plot via `--var-x` and `--var-y`, optionally applying `--filter-by` and `--bin-by` to the values. The plot is organized according to `--fig-by`, `--row-by`, `--col-by`, and `--curve-by`.
|
Control the variables to plot via `--var-x` and `--var-y`, optionally applying `--filter-by` and `--bin-by` to the values. The plot is organized according to `--fig-by`, `--row-by`, `--col-by`, and `--curve-by`.
|
||||||
|
|
||||||
Example commands for visualizing [SLA Scanner](#sla-scanner) results:
|
Example commands for visualizing [Workload Explorer](#workload-explorer) results:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Latency increases as the request rate increases
|
EXPERIMENT_DIR=${1:-"benchmarks/results/demo"}
|
||||||
vllm bench sweep plot benchmarks/results/<timestamp> \
|
|
||||||
--var-x request_rate \
|
# Latency increases as the workload increases
|
||||||
--var-y p99_ttft_ms \
|
vllm bench sweep plot $EXPERIMENT_DIR \
|
||||||
--row-by random_input_len \
|
--var-x max_concurrency \
|
||||||
--col-by random_output_len \
|
--var-y median_ttft_ms \
|
||||||
|
--col-by _benchmark_name \
|
||||||
--curve-by max_num_seqs,max_num_batched_tokens \
|
--curve-by max_num_seqs,max_num_batched_tokens \
|
||||||
--filter-by 'request_rate<=128'
|
--fig-name latency_curve
|
||||||
|
|
||||||
|
# Throughput saturates as workload increases
|
||||||
|
vllm bench sweep plot $EXPERIMENT_DIR \
|
||||||
|
--var-x max_concurrency \
|
||||||
|
--var-y total_token_throughput \
|
||||||
|
--col-by _benchmark_name \
|
||||||
|
--curve-by max_num_seqs,max_num_batched_tokens \
|
||||||
|
--fig-name throughput_curve
|
||||||
|
|
||||||
# Tradeoff between latency and throughput
|
# Tradeoff between latency and throughput
|
||||||
vllm bench sweep plot benchmarks/results/<timestamp> \
|
vllm bench sweep plot $EXPERIMENT_DIR \
|
||||||
--var-x request_throughput \
|
--var-x total_token_throughput \
|
||||||
--var-y median_ttft_ms \
|
--var-y median_ttft_ms \
|
||||||
--row-by random_input_len \
|
--col-by _benchmark_name \
|
||||||
--col-by random_output_len \
|
|
||||||
--curve-by max_num_seqs,max_num_batched_tokens \
|
--curve-by max_num_seqs,max_num_batched_tokens \
|
||||||
--filter-by 'request_rate<=128'
|
--fig-name latency_throughput
|
||||||
```
|
```
|
||||||
|
|
||||||
!!! tip
|
!!! tip
|
||||||
@@ -233,7 +251,9 @@ Higher concurrency or batch size can raise GPU efficiency (per-GPU), but can add
|
|||||||
Example:
|
Example:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
vllm bench sweep plot_pareto benchmarks/results/<timestamp> \
|
EXPERIMENT_DIR=${1:-"benchmarks/results/demo"}
|
||||||
|
|
||||||
|
vllm bench sweep plot_pareto $EXPERIMENT_DIR \
|
||||||
--label-by max_concurrency,tensor_parallel_size,pipeline_parallel_size
|
--label-by max_concurrency,tensor_parallel_size,pipeline_parallel_size
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +0,0 @@
|
|||||||
# vllm bench sweep serve_sla
|
|
||||||
|
|
||||||
## JSON CLI Arguments
|
|
||||||
|
|
||||||
--8<-- "docs/cli/json_tip.inc.md"
|
|
||||||
|
|
||||||
## Arguments
|
|
||||||
|
|
||||||
--8<-- "docs/generated/argparse/bench_sweep_serve_sla.inc.md"
|
|
||||||
9
docs/cli/bench/sweep/serve_workload.md
Normal file
9
docs/cli/bench/sweep/serve_workload.md
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
# vllm bench sweep serve_workload
|
||||||
|
|
||||||
|
## JSON CLI Arguments
|
||||||
|
|
||||||
|
--8<-- "docs/cli/json_tip.inc.md"
|
||||||
|
|
||||||
|
## Arguments
|
||||||
|
|
||||||
|
--8<-- "docs/generated/argparse/bench_sweep_serve_workload.inc.md"
|
||||||
@@ -168,17 +168,18 @@ Priority is **1 = highest** (tried first).
|
|||||||
| `FLASHINFER` | TRTLLM† | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ✅ | ❌ | ✅ | Decoder | 10.x |
|
| `FLASHINFER` | TRTLLM† | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ✅ | ❌ | ✅ | Decoder | 10.x |
|
||||||
| `FLASH_ATTN` | FA2* | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥8.0 |
|
| `FLASH_ATTN` | FA2* | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥8.0 |
|
||||||
| `FLASH_ATTN` | FA3* | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ❌ | ✅ | All | 9.x |
|
| `FLASH_ATTN` | FA3* | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ❌ | ✅ | All | 9.x |
|
||||||
|
| `FLASH_ATTN` | FA4* | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥10.0 |
|
||||||
| `FLASH_ATTN_DIFFKV` | | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ✅ | Decoder | Any |
|
| `FLASH_ATTN_DIFFKV` | | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ✅ | Decoder | Any |
|
||||||
| `FLEX_ATTENTION` | | fp16, bf16, fp32 | `auto`, `bfloat16` | Any | Any | ❌ | ✅ | ❌ | Decoder, Encoder Only | Any |
|
| `FLEX_ATTENTION` | | fp16, bf16, fp32 | `auto`, `bfloat16` | Any | Any | ❌ | ✅ | ❌ | Decoder, Encoder Only | Any |
|
||||||
| `ROCM_AITER_FA` | | fp16, bf16 | `auto` | 16, 32 | 64, 128, 256 | ❌ | ❌ | ❌ | Decoder | N/A |
|
| `ROCM_AITER_FA` | | fp16, bf16 | `auto` | 16, 32 | 64, 128, 256 | ❌ | ❌ | ❌ | Decoder | N/A |
|
||||||
| `ROCM_AITER_UNIFIED_ATTN` | | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | Decoder | N/A |
|
| `ROCM_AITER_UNIFIED_ATTN` | | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | All | N/A |
|
||||||
| `ROCM_ATTN` | | fp16, bf16, fp32 | `auto` | 16, 32, 544 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | Decoder | N/A |
|
| `ROCM_ATTN` | | fp16, bf16, fp32 | `auto` | 16, 32, 544 | 32, 64, 80, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | All | N/A |
|
||||||
| `TREE_ATTN` | | fp16, bf16 | `auto` | %16 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | Decoder | Any |
|
| `TREE_ATTN` | | fp16, bf16 | `auto` | %16 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | Decoder | Any |
|
||||||
| `TRITON_ATTN` | | fp16, bf16, fp32 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ✅ | ❌ | All | Any |
|
| `TRITON_ATTN` | | fp16, bf16, fp32 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ✅ | ❌ | All | Any |
|
||||||
|
|
||||||
> **†** FlashInfer uses TRTLLM attention on Blackwell (SM100), which supports sinks. Disable via `--attention-config.use_trtllm_attention=0`.
|
> **†** FlashInfer uses TRTLLM attention on Blackwell (SM100), which supports sinks. Disable via `--attention-config.use_trtllm_attention=0`.
|
||||||
>
|
>
|
||||||
> **\*** Specify the FlashAttention version via `--attention-config.flash_attn_version=2` or `3`. Default is FA3 on SM90, FA2 otherwise.
|
> **\*** Specify the FlashAttention version via `--attention-config.flash_attn_version=2`, `3`, or `4`. Default is FA4 on SM100+ (Blackwell), FA3 on SM90 (Hopper), FA2 otherwise.
|
||||||
|
|
||||||
## MLA (Multi-head Latent Attention) Backends
|
## MLA (Multi-head Latent Attention) Backends
|
||||||
|
|
||||||
|
|||||||
@@ -54,6 +54,8 @@ For example:
|
|||||||
--8<-- "vllm/model_executor/layers/attention/mm_encoder_attention.py:mm_encoder_attn"
|
--8<-- "vllm/model_executor/layers/attention/mm_encoder_attention.py:mm_encoder_attn"
|
||||||
|
|
||||||
--8<-- "vllm/model_executor/layers/mla.py:multi_head_latent_attention"
|
--8<-- "vllm/model_executor/layers/mla.py:multi_head_latent_attention"
|
||||||
|
|
||||||
|
--8<-- "vllm/model_executor/models/deepencoder.py:rel_pos_attention"
|
||||||
```
|
```
|
||||||
|
|
||||||
**2. Activation:**
|
**2. Activation:**
|
||||||
|
|||||||
@@ -81,7 +81,7 @@ The current implementation has all `dbo_yield` and `dbo_maybe_run_recv_hook` cal
|
|||||||
|
|
||||||
The `make_ubatch_context` function initializes two `UBatchContexts`, one for each UBatch thread. It takes two CUDA streams, the preexisting `ForwardContexts` and a CPU thread barrier. This function should be used exclusively to instantiate `UBatchContexts`. It will handle all of the event initialization.
|
The `make_ubatch_context` function initializes two `UBatchContexts`, one for each UBatch thread. It takes two CUDA streams, the preexisting `ForwardContexts` and a CPU thread barrier. This function should be used exclusively to instantiate `UBatchContexts`. It will handle all of the event initialization.
|
||||||
|
|
||||||
The `dbo_register_recv_hook` method registers a callback that can be returned by the `FusedMoEPrepareAndFinalize` class in the other UBatch thread’s `UBatchContext`. The callback will be run when the other thread calls `dbo_maybe_run_recv_hook`. This is typically used to wait on an all-to-all kernel.
|
The `dbo_register_recv_hook` method registers a callback that can be returned by the `FusedMoEPrepareAndFinalizeModular` class in the other UBatch thread’s `UBatchContext`. The callback will be run when the other thread calls `dbo_maybe_run_recv_hook`. This is typically used to wait on an all-to-all kernel.
|
||||||
|
|
||||||
The `dbo_maybe_run_recv_hook` method runs a callback that’s set by the `dbo_register_recv_hook` function if that callback exists.
|
The `dbo_maybe_run_recv_hook` method runs a callback that’s set by the `dbo_register_recv_hook` function if that callback exists.
|
||||||
|
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ Based on the format of the input activations, FusedMoE implementations are broad
|
|||||||
The input activation format completely depends on the All2All Dispatch being used.
|
The input activation format completely depends on the All2All Dispatch being used.
|
||||||
|
|
||||||
* In the Contiguous variant, the All2All Dispatch returns the activations as a contiguous tensor of shape (M, K) along with TopK Ids and TopK weights of shape (M, num_topk). Look at `DeepEPHTPrepareAndFinalize` for an example.
|
* In the Contiguous variant, the All2All Dispatch returns the activations as a contiguous tensor of shape (M, K) along with TopK Ids and TopK weights of shape (M, num_topk). Look at `DeepEPHTPrepareAndFinalize` for an example.
|
||||||
* In the Batched variant, the All2All Dispatch returns the activations as a tensor of shape (num_experts, max_tokens, K). Here, the activations/tokens that subscribe to the same expert are batched together. Note that not all entries of the tensor are valid. The activations tensor is typically accompanied by an `expert_num_tokens` tensor of size `num_experts`, where `expert_num_tokens[i]` indicates the number of valid tokens that subscribe to the ith expert. Look at `PplxPrepareAndFinalize` or `DeepEPLLPrepareAndFinalize` for an example.
|
* In the Batched variant, the All2All Dispatch returns the activations as a tensor of shape (num_experts, max_tokens, K). Here, the activations/tokens that subscribe to the same expert are batched together. Note that not all entries of the tensor are valid. The activations tensor is typically accompanied by an `expert_num_tokens` tensor of size `num_experts`, where `expert_num_tokens[i]` indicates the number of valid tokens that subscribe to the ith expert. Look at `DeepEPLLPrepareAndFinalize` for an example.
|
||||||
|
|
||||||
The FusedMoE operation is generally made of multiple operations, in both the Contiguous and Batched variants, as described in the diagrams below
|
The FusedMoE operation is generally made of multiple operations, in both the Contiguous and Batched variants, as described in the diagrams below
|
||||||
|
|
||||||
@@ -37,31 +37,31 @@ The rest of the document will focus on the Contiguous / Non-Batched case. Extrap
|
|||||||
FusedMoEModularKernel splits the FusedMoE operation into 3 parts,
|
FusedMoEModularKernel splits the FusedMoE operation into 3 parts,
|
||||||
|
|
||||||
1. TopKWeightAndReduce
|
1. TopKWeightAndReduce
|
||||||
2. FusedMoEPrepareAndFinalize
|
2. FusedMoEPrepareAndFinalizeModular
|
||||||
3. FusedMoEPermuteExpertsUnpermute
|
3. FusedMoEExpertsModular
|
||||||
|
|
||||||
### TopKWeightAndReduce
|
### TopKWeightAndReduce
|
||||||
|
|
||||||
The TopK Weight Application and Reduction components happen right after the Unpermute operation and before the All2All Combine. Note that the `FusedMoEPermuteExpertsUnpermute` is responsible for the Unpermute and `FusedMoEPrepareAndFinalize` is responsible for the All2All Combine. There is value in doing the TopK Weight Application and Reduction in the `FusedMoEPermuteExpertsUnpermute`. But some implementations choose to do it `FusedMoEPrepareAndFinalize`. In order to enable this flexibility, we have a TopKWeightAndReduce abstract class.
|
The TopK Weight Application and Reduction components happen right after the Unpermute operation and before the All2All Combine. Note that the `FusedMoEExpertsModular` is responsible for the Unpermute and `FusedMoEPrepareAndFinalizeModular` is responsible for the All2All Combine. There is value in doing the TopK Weight Application and Reduction in the `FusedMoEExpertsModular`. But some implementations choose to do it `FusedMoEPrepareAndFinalizeModular`. In order to enable this flexibility, we have a TopKWeightAndReduce abstract class.
|
||||||
|
|
||||||
Please find the implementations of TopKWeightAndReduce [here](../../vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py).
|
Please find the implementations of TopKWeightAndReduce [here](../../vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py).
|
||||||
|
|
||||||
`FusedMoEPrepareAndFinalize::finalize()` method accepts a `TopKWeightAndReduce` argument that is invoked inside the method.
|
`FusedMoEPrepareAndFinalizeModular::finalize()` method accepts a `TopKWeightAndReduce` argument that is invoked inside the method.
|
||||||
The `FusedMoEModularKernel` acts as a bridge between the `FusedMoEPermuteExpertsUnpermute` and `FusedMoEPerpareAndFinalize` implementations to determine where the TopK Weight Application and Reduction happens.
|
The `FusedMoEModularKernel` acts as a bridge between the `FusedMoEExpertsModular` and `FusedMoEPerpareAndFinalize` implementations to determine where the TopK Weight Application and Reduction happens.
|
||||||
|
|
||||||
* `FusedMoEPermuteExpertsUnpermute::finalize_weight_and_reduce_impl` method returns `TopKWeightAndReduceNoOp` if the `FusedMoEPermuteExpertsUnpermute` implementation does the weight application and reduction itself.
|
* `FusedMoEExpertsModular::finalize_weight_and_reduce_impl` method returns `TopKWeightAndReduceNoOp` if the `FusedMoEExpertsModular` implementation does the weight application and reduction itself.
|
||||||
* `FusedMoEPermuteExpertsUnpermute::finalize_weight_and_reduce_impl` method returns `TopKWeightAndReduceContiguous` / `TopKWeightAndReduceNaiveBatched` / `TopKWeightAndReduceDelegate` if the `FusedMoEPermuteExpertsUnpermute` implementation needs the `FusedMoEPrepareAndFinalize::finalize()` to do the weight application and reduction.
|
* `FusedMoEExpertsModular::finalize_weight_and_reduce_impl` method returns `TopKWeightAndReduceContiguous` / `TopKWeightAndReduceNaiveBatched` / `TopKWeightAndReduceDelegate` if the `FusedMoEExpertsModular` implementation needs the `FusedMoEPrepareAndFinalizeModular::finalize()` to do the weight application and reduction.
|
||||||
|
|
||||||
### FusedMoEPrepareAndFinalize
|
### FusedMoEPrepareAndFinalizeModular
|
||||||
|
|
||||||
The `FusedMoEPrepareAndFinalize` abstract class exposes `prepare`, `prepare_no_receive` and `finalize` functions.
|
The `FusedMoEPrepareAndFinalizeModular` abstract class exposes `prepare`, `prepare_no_receive` and `finalize` functions.
|
||||||
The `prepare` function is responsible for input activation Quantization and All2All Dispatch. If implemented, The `prepare_no_receive` is like `prepare` except it does not wait to receive results from other workers. Instead it returns a "receiver" callback that must be invoked to wait for the final results of worker. It is not required that this method is supported by all `FusedMoEPrepareAndFinalize` classes, but if it is available, it can be used to interleave work with the initial all to all communication, e.g. interleaving shared experts with fused experts. The `finalize` function is responsible for invoking the All2All Combine. Additionally the `finalize` function may or may not do the TopK weight application and reduction (Please refer to the TopKWeightAndReduce section)
|
The `prepare` function is responsible for input activation Quantization and All2All Dispatch. If implemented, The `prepare_no_receive` is like `prepare` except it does not wait to receive results from other workers. Instead it returns a "receiver" callback that must be invoked to wait for the final results of worker. It is not required that this method is supported by all `FusedMoEPrepareAndFinalizeModular` classes, but if it is available, it can be used to interleave work with the initial all to all communication, e.g. interleaving shared experts with fused experts. The `finalize` function is responsible for invoking the All2All Combine. Additionally the `finalize` function may or may not do the TopK weight application and reduction (Please refer to the TopKWeightAndReduce section)
|
||||||
|
|
||||||

|

|
||||||
|
|
||||||
### FusedMoEPermuteExpertsUnpermute
|
### FusedMoEExpertsModular
|
||||||
|
|
||||||
The `FusedMoEPermuteExpertsUnpermute` class is where the crux of the MoE operations happen. The `FusedMoEPermuteExpertsUnpermute` abstract class exposes a few important functions,
|
The `FusedMoEExpertsModular` class is where the crux of the MoE operations happen. The `FusedMoEExpertsModular` abstract class exposes a few important functions,
|
||||||
|
|
||||||
* apply()
|
* apply()
|
||||||
* workspace_shapes()
|
* workspace_shapes()
|
||||||
@@ -81,25 +81,25 @@ The `apply` method is where the implementations perform
|
|||||||
|
|
||||||
#### workspace_shapes()
|
#### workspace_shapes()
|
||||||
|
|
||||||
The core FusedMoE implementation performs a series of operations. It would be inefficient to create output memory for each of these operations separately. To that effect, implementations are required to declare 2 workspace shapes, the workspace datatype and the FusedMoE output shape as outputs of the workspace_shapes() method. This information is used to allocate the workspace tensors and the output tensor in `FusedMoEModularKernel::forward()` and passed on to the `FusedMoEPermuteExpertsUnpermute::apply()` method. The workspaces could then be used as intermediate buffers in the FusedMoE implementation.
|
The core FusedMoE implementation performs a series of operations. It would be inefficient to create output memory for each of these operations separately. To that effect, implementations are required to declare 2 workspace shapes, the workspace datatype and the FusedMoE output shape as outputs of the workspace_shapes() method. This information is used to allocate the workspace tensors and the output tensor in `FusedMoEModularKernel::forward()` and passed on to the `FusedMoEExpertsModular::apply()` method. The workspaces could then be used as intermediate buffers in the FusedMoE implementation.
|
||||||
|
|
||||||
#### finalize_weight_and_reduce_impl()
|
#### finalize_weight_and_reduce_impl()
|
||||||
|
|
||||||
It is sometimes efficient to perform TopK weight application and Reduction inside the `FusedMoEPermuteExpertsUnpermute::apply()`. Find an example [here](https://github.com/vllm-project/vllm/pull/20228). We have a `TopKWeightAndReduce` abstract class to facilitate such implementations. Please refer to the TopKWeightAndReduce section.
|
It is sometimes efficient to perform TopK weight application and Reduction inside the `FusedMoEExpertsModular::apply()`. Find an example [here](https://github.com/vllm-project/vllm/pull/20228). We have a `TopKWeightAndReduce` abstract class to facilitate such implementations. Please refer to the TopKWeightAndReduce section.
|
||||||
`FusedMoEPermuteExpertsUnpermute::finalize_weight_and_reduce_impl()` returns the `TopKWeightAndReduce` object that the implementation wants the `FusedMoEPrepareAndFinalize::finalize()` to use.
|
`FusedMoEExpertsModular::finalize_weight_and_reduce_impl()` returns the `TopKWeightAndReduce` object that the implementation wants the `FusedMoEPrepareAndFinalizeModular::finalize()` to use.
|
||||||
|
|
||||||

|

|
||||||
|
|
||||||
### FusedMoEModularKernel
|
### FusedMoEModularKernel
|
||||||
|
|
||||||
`FusedMoEModularKernel` is composed of the `FusedMoEPrepareAndFinalize` and `FusedMoEPermuteExpertsUnpermute` objects.
|
`FusedMoEModularKernel` is composed of the `FusedMoEPrepareAndFinalizeModular` and `FusedMoEExpertsModular` objects.
|
||||||
`FusedMoEModularKernel` pseudocode/sketch,
|
`FusedMoEModularKernel` pseudocode/sketch,
|
||||||
|
|
||||||
```py
|
```py
|
||||||
class FusedMoEModularKernel:
|
class FusedMoEModularKernel:
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
prepare_finalize: FusedMoEPrepareAndFinalizeModular,
|
||||||
fused_experts: FusedMoEPermuteExpertsUnpermute):
|
fused_experts: FusedMoEExpertsModular):
|
||||||
|
|
||||||
self.prepare_finalize = prepare_finalize
|
self.prepare_finalize = prepare_finalize
|
||||||
self.fused_experts = fused_experts
|
self.fused_experts = fused_experts
|
||||||
@@ -128,54 +128,53 @@ class FusedMoEModularKernel:
|
|||||||
|
|
||||||
## How-To
|
## How-To
|
||||||
|
|
||||||
### How To Add a FusedMoEPrepareAndFinalize Type
|
### How To Add a FusedMoEPrepareAndFinalizeModular Type
|
||||||
|
|
||||||
Typically a FusedMoEPrepareAndFinalize type is backed by an All2All Dispatch & Combine implementation / kernel. For example,
|
Typically a FusedMoEPrepareAndFinalizeModular type is backed by an All2All Dispatch & Combine implementation / kernel. For example,
|
||||||
|
|
||||||
* PplxPrepareAndFinalize type is backed by Pplx All2All kernels,
|
|
||||||
* DeepEPHTPrepareAndFinalize type is backed by DeepEP High-Throughput All2All kernels, and
|
* DeepEPHTPrepareAndFinalize type is backed by DeepEP High-Throughput All2All kernels, and
|
||||||
* DeepEPLLPrepareAndFinalize type is backed by DeepEP Low-Latency All2All kernels.
|
* DeepEPLLPrepareAndFinalize type is backed by DeepEP Low-Latency All2All kernels.
|
||||||
|
|
||||||
#### Step 1: Add an All2All manager
|
#### Step 1: Add an All2All manager
|
||||||
|
|
||||||
The purpose of the All2All Manager is to set up the All2All kernel implementations. The `FusedMoEPrepareAndFinalize` implementations typically fetch a kernel-implementation "handle" from the All2All Manager to invoke the Dispatch and Combine functions. Please look at the All2All Manager implementations [here](../../vllm/distributed/device_communicators/all2all.py).
|
The purpose of the All2All Manager is to set up the All2All kernel implementations. The `FusedMoEPrepareAndFinalizeModular` implementations typically fetch a kernel-implementation "handle" from the All2All Manager to invoke the Dispatch and Combine functions. Please look at the All2All Manager implementations [here](../../vllm/distributed/device_communicators/all2all.py).
|
||||||
|
|
||||||
#### Step 2: Add a FusedMoEPrepareAndFinalize Type
|
#### Step 2: Add a FusedMoEPrepareAndFinalizeModular Type
|
||||||
|
|
||||||
This section describes the significance of the various functions exposed by the `FusedMoEPrepareAndFinalize` abstract class.
|
This section describes the significance of the various functions exposed by the `FusedMoEPrepareAndFinalizeModular` abstract class.
|
||||||
|
|
||||||
`FusedMoEPrepareAndFinalize::prepare()`: The prepare method implements the Quantization and All2All Dispatch. Typically the Dispatch function from the relevant All2All Manager is invoked.
|
`FusedMoEPrepareAndFinalizeModular::prepare()`: The prepare method implements the Quantization and All2All Dispatch. Typically the Dispatch function from the relevant All2All Manager is invoked.
|
||||||
|
|
||||||
`FusedMoEPrepareAndFinalize::has_prepare_no_receive()`: Indicates whether or not this subclass implements `prepare_no_receive`. Defaults to False.
|
`FusedMoEPrepareAndFinalizeModular::has_prepare_no_receive()`: Indicates whether or not this subclass implements `prepare_no_receive`. Defaults to False.
|
||||||
|
|
||||||
`FusedMoEPrepareAndFinalize::prepare_no_receive()`: The prepare_no_receive method implements the Quantization and All2All Dispatch. It does not wait for the result of the dispatch operation but instead returns a thunk that can be invoked to wait for the final results. Typically the Dispatch function from the relevant All2All Manager is invoked.
|
`FusedMoEPrepareAndFinalizeModular::prepare_no_receive()`: The prepare_no_receive method implements the Quantization and All2All Dispatch. It does not wait for the result of the dispatch operation but instead returns a thunk that can be invoked to wait for the final results. Typically the Dispatch function from the relevant All2All Manager is invoked.
|
||||||
|
|
||||||
`FusedMoEPrepareAndFinalize::finalize()`: Maybe perform TopK Weight Application and Reduction and All2All Combine. Typically the Combine function from the relevant All2AllManager is invoked.
|
`FusedMoEPrepareAndFinalizeModular::finalize()`: Maybe perform TopK Weight Application and Reduction and All2All Combine. Typically the Combine function from the relevant All2AllManager is invoked.
|
||||||
|
|
||||||
`FusedMoEPrepareAndFinalize::activation_format()`: Return `FusedMoEActivationFormat.BatchedExperts` if the output of the prepare method (i.e. the All2All dispatch) is Batched. Return `FusedMoEActivationFormat.Standard` otherwise.
|
`FusedMoEPrepareAndFinalizeModular::activation_format()`: Return `FusedMoEActivationFormat.BatchedExperts` if the output of the prepare method (i.e. the All2All dispatch) is Batched. Return `FusedMoEActivationFormat.Standard` otherwise.
|
||||||
|
|
||||||
`FusedMoEPrepareAndFinalize::topk_indices_dtype()`: Data type of the TopK ids. Some All2All kernels have strict requirements pertaining to the data type of the TopK ids. This requirement is passed on to the `FusedMoE::select_experts` function so it could be respected. If there are no strict requirements return None.
|
`FusedMoEPrepareAndFinalizeModular::topk_indices_dtype()`: Data type of the TopK ids. Some All2All kernels have strict requirements pertaining to the data type of the TopK ids. This requirement is passed on to the `FusedMoE::select_experts` function so it could be respected. If there are no strict requirements return None.
|
||||||
|
|
||||||
`FusedMoEPrepareAndFinalize::max_num_tokens_per_rank()`: This is the maximum number of tokens that would be submitted to the All2All Dispatch at once.
|
`FusedMoEPrepareAndFinalizeModular::max_num_tokens_per_rank()`: This is the maximum number of tokens that would be submitted to the All2All Dispatch at once.
|
||||||
|
|
||||||
`FusedMoEPrepareAndFinalize::num_dispatchers()`: Total number of dispatching units. This value determines the size of the Dispatch output. The Dispatch output is of shape (num_local_experts, max_num_tokens, K). Here max_num_tokens = num_dispatchers() * max_num_tokens_per_rank().
|
`FusedMoEPrepareAndFinalizeModular::num_dispatchers()`: Total number of dispatching units. This value determines the size of the Dispatch output. The Dispatch output is of shape (num_local_experts, max_num_tokens, K). Here max_num_tokens = num_dispatchers() * max_num_tokens_per_rank().
|
||||||
|
|
||||||
We suggest picking an already existing `FusedMoEPrepareAndFinalize` implementation that matches your All2All implementation closely and using it as a reference.
|
We suggest picking an already existing `FusedMoEPrepareAndFinalizeModular` implementation that matches your All2All implementation closely and using it as a reference.
|
||||||
|
|
||||||
### How To Add a FusedMoEPermuteExpertsUnpermute Type
|
### How To Add a FusedMoEExpertsModular Type
|
||||||
|
|
||||||
FusedMoEPermuteExpertsUnpermute performs the core of the FusedMoE operations. The various functions exposed by the abstract class and their significance is as follows,
|
FusedMoEExpertsModular performs the core of the FusedMoE operations. The various functions exposed by the abstract class and their significance is as follows,
|
||||||
|
|
||||||
`FusedMoEPermuteExpertsUnpermute::activation_formats()`: Return the supported Input and Output activation formats. i.e. Contiguous / Batched format.
|
`FusedMoEExpertsModular::activation_formats()`: Return the supported Input and Output activation formats. i.e. Contiguous / Batched format.
|
||||||
|
|
||||||
`FusedMoEPermuteExpertsUnpermute::supports_chunking()`: Return True if the implementation supports chunking. Typically
|
`FusedMoEExpertsModular::supports_chunking()`: Return True if the implementation supports chunking. Typically
|
||||||
implementations that input `FusedMoEActivationFormat.Standard` support chunking and `FusedMoEActivationFormat.BatchedExperts` do not.
|
implementations that input `FusedMoEActivationFormat.Standard` support chunking and `FusedMoEActivationFormat.BatchedExperts` do not.
|
||||||
|
|
||||||
`FusedMoEPermuteExpertsUnpermute::supports_expert_map()`: Return True if the implementation supports expert map.
|
`FusedMoEExpertsModular::supports_expert_map()`: Return True if the implementation supports expert map.
|
||||||
|
|
||||||
`FusedMoEPermuteExpertsUnpermute::workspace_shapes()` /
|
`FusedMoEExpertsModular::workspace_shapes()` /
|
||||||
`FusedMoEPermuteExpertsUnpermute::finalize_weight_and_reduce_impl` /
|
`FusedMoEExpertsModular::finalize_weight_and_reduce_impl` /
|
||||||
`FusedMoEPermuteExpertsUnpermute::apply`: Refer to `FusedMoEPermuteExpertsUnpermute` section above.
|
`FusedMoEExpertsModular::apply`: Refer to `FusedMoEExpertsModular` section above.
|
||||||
|
|
||||||
### FusedMoEModularKernel Initialization
|
### FusedMoEModularKernel Initialization
|
||||||
|
|
||||||
@@ -187,14 +186,14 @@ implementations that input `FusedMoEActivationFormat.Standard` support chunking
|
|||||||
|
|
||||||
#### maybe_make_prepare_finalize
|
#### maybe_make_prepare_finalize
|
||||||
|
|
||||||
The `maybe_make_prepare_finalize` method is responsible for constructing an instance of `FusedMoEPrepareAndFinalize` when appropriate based on the current all2all backend, e.g. when EP + DP is enabled. The base class method currently constructs all the `FusedMoEPrepareAndFinalize` objects for the EP+DP case. Derived classes can override this method to construct prepare/finalize objects for different scenarios, e.g. `ModelOptNvFp4FusedMoE` can construct a `FlashInferCutlassMoEPrepareAndFinalize` for the EP+TP case.
|
The `maybe_make_prepare_finalize` method is responsible for constructing an instance of `FusedMoEPrepareAndFinalizeModular` when appropriate based on the current all2all backend, e.g. when EP + DP is enabled. The base class method currently constructs all the `FusedMoEPrepareAndFinalizeModular` objects for the EP+DP case. Derived classes can override this method to construct prepare/finalize objects for different scenarios, e.g. `ModelOptNvFp4FusedMoE` can construct a `FlashInferCutlassMoEPrepareAndFinalize` for the EP+TP case.
|
||||||
Please refer to the implementations in,
|
Please refer to the implementations in,
|
||||||
|
|
||||||
* `ModelOptNvFp4FusedMoE`
|
* `ModelOptNvFp4FusedMoE`
|
||||||
|
|
||||||
#### select_gemm_impl
|
#### select_gemm_impl
|
||||||
|
|
||||||
The `select_gemm_impl` method is undefined in the base class. It is the responsibility of the derived class to implement a method that constructs a valid/appropriate `FusedMoEPermuteExpertsUnpermute` object.
|
The `select_gemm_impl` method is undefined in the base class. It is the responsibility of the derived class to implement a method that constructs a valid/appropriate `FusedMoEExpertsModular` object.
|
||||||
Please refer to the implementations in,
|
Please refer to the implementations in,
|
||||||
|
|
||||||
* `UnquantizedFusedMoEMethod`
|
* `UnquantizedFusedMoEMethod`
|
||||||
@@ -206,7 +205,7 @@ derived classes.
|
|||||||
|
|
||||||
#### init_prepare_finalize
|
#### init_prepare_finalize
|
||||||
|
|
||||||
Based on the input and env settings, the `init_prepare_finalize` method creates the appropriate `FusedMoEPrepareAndFinalize` object. The method then queries `select_gemm_impl` for the appropriate `FusedMoEPermuteExpertsUnpermute` object and builds the `FusedMoEModularKernel` object
|
Based on the input and env settings, the `init_prepare_finalize` method creates the appropriate `FusedMoEPrepareAndFinalizeModular` object. The method then queries `select_gemm_impl` for the appropriate `FusedMoEExpertsModular` object and builds the `FusedMoEModularKernel` object
|
||||||
|
|
||||||
Please take a look at [init_prepare_finalize](https://github.com/vllm-project/vllm/blob/1cbf951ba272c230823b947631065b826409fa62/vllm/model_executor/layers/fused_moe/layer.py#L188).
|
Please take a look at [init_prepare_finalize](https://github.com/vllm-project/vllm/blob/1cbf951ba272c230823b947631065b826409fa62/vllm/model_executor/layers/fused_moe/layer.py#L188).
|
||||||
**Important**: The `FusedMoEMethodBase` derived classes use the `FusedMoEMethodBase::fused_experts` object in their `apply` methods. When settings permit the construction of a valid `FusedMoEModularKernel` object, we override `FusedMoEMethodBase::fused_experts` with it. This essentially makes the derived classes agnostic to what FusedMoE implementation is used.
|
**Important**: The `FusedMoEMethodBase` derived classes use the `FusedMoEMethodBase::fused_experts` object in their `apply` methods. When settings permit the construction of a valid `FusedMoEModularKernel` object, we override `FusedMoEMethodBase::fused_experts` with it. This essentially makes the derived classes agnostic to what FusedMoE implementation is used.
|
||||||
@@ -215,9 +214,9 @@ Please take a look at [init_prepare_finalize](https://github.com/vllm-project/vl
|
|||||||
|
|
||||||
We have `FusedMoEModularKernel` unit tests at [test_modular_kernel_combinations.py](../../tests/kernels/moe/test_modular_kernel_combinations.py).
|
We have `FusedMoEModularKernel` unit tests at [test_modular_kernel_combinations.py](../../tests/kernels/moe/test_modular_kernel_combinations.py).
|
||||||
|
|
||||||
The unit test iterates through all combinations of `FusedMoEPrepareAndFinalize` and `FusedMoEPermuteExpertsUnpermute` types and if they are
|
The unit test iterates through all combinations of `FusedMoEPrepareAndFinalizeModular` and `FusedMoEExpertsModular` types and if they are
|
||||||
compatible, runs some correctness tests.
|
compatible, runs some correctness tests.
|
||||||
If you are adding some `FusedMoEPrepareAndFinalize` / `FusedMoEPermuteExpertsUnpermute` implementations,
|
If you are adding some `FusedMoEPrepareAndFinalizeModular` / `FusedMoEExpertsModular` implementations,
|
||||||
|
|
||||||
1. Add the implementation type to `MK_ALL_PREPARE_FINALIZE_TYPES` and `MK_FUSED_EXPERT_TYPES` in [mk_objects.py](../../tests/kernels/moe/modular_kernel_tools/mk_objects.py) respectively.
|
1. Add the implementation type to `MK_ALL_PREPARE_FINALIZE_TYPES` and `MK_FUSED_EXPERT_TYPES` in [mk_objects.py](../../tests/kernels/moe/modular_kernel_tools/mk_objects.py) respectively.
|
||||||
2. Update `Config::is_batched_prepare_finalize()`, `Config::is_batched_fused_experts()`, `Config::is_standard_fused_experts()`,
|
2. Update `Config::is_batched_prepare_finalize()`, `Config::is_batched_fused_experts()`, `Config::is_standard_fused_experts()`,
|
||||||
@@ -226,24 +225,24 @@ If you are adding some `FusedMoEPrepareAndFinalize` / `FusedMoEPermuteExpertsUnp
|
|||||||
|
|
||||||
Doing this will add the new implementation to the test suite.
|
Doing this will add the new implementation to the test suite.
|
||||||
|
|
||||||
### How To Check `FusedMoEPrepareAndFinalize` & `FusedMoEPermuteExpertsUnpermute` Compatibility
|
### How To Check `FusedMoEPrepareAndFinalizeModular` & `FusedMoEExpertsModular` Compatibility
|
||||||
|
|
||||||
The unit test file [test_modular_kernel_combinations.py](../../tests/kernels/moe/test_modular_kernel_combinations.py) can also be executed as a standalone script.
|
The unit test file [test_modular_kernel_combinations.py](../../tests/kernels/moe/test_modular_kernel_combinations.py) can also be executed as a standalone script.
|
||||||
Example: `python3 -m tests.kernels.moe.test_modular_kernel_combinations --pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts`
|
Example: `python3 -m tests.kernels.moe.test_modular_kernel_combinations --pf-type DeepEPLLPrepareAndFinalize --experts-type BatchedTritonExperts`
|
||||||
As a side effect, this script can be used to test `FusedMoEPrepareAndFinalize` & `FusedMoEPermuteExpertsUnpermute` compatibility. When invoked
|
As a side effect, this script can be used to test `FusedMoEPrepareAndFinalizeModular` & `FusedMoEExpertsModular` compatibility. When invoked
|
||||||
with incompatible types, the script will error.
|
with incompatible types, the script will error.
|
||||||
|
|
||||||
### How To Profile
|
### How To Profile
|
||||||
|
|
||||||
Please take a look at [profile_modular_kernel.py](../../tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py)
|
Please take a look at [profile_modular_kernel.py](../../tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py)
|
||||||
The script can be used to generate Torch traces for a single `FusedMoEModularKernel::forward()` call for any compatible
|
The script can be used to generate Torch traces for a single `FusedMoEModularKernel::forward()` call for any compatible
|
||||||
`FusedMoEPrepareAndFinalize` and `FusedMoEPermuteExpertsUnpermute` types.
|
`FusedMoEPrepareAndFinalizeModular` and `FusedMoEExpertsModular` types.
|
||||||
Example: `python3 -m tests.kernels.moe.modular_kernel_tools.profile_modular_kernel --pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts`
|
Example: `python3 -m tests.kernels.moe.modular_kernel_tools.profile_modular_kernel --pf-type DeepEPLLPrepareAndFinalize --experts-type BatchedTritonExperts`
|
||||||
|
|
||||||
## FusedMoEPrepareAndFinalize Implementations
|
## FusedMoEPrepareAndFinalizeModular Implementations
|
||||||
|
|
||||||
See [Fused MoE Kernel features](./moe_kernel_features.md#fused-moe-modular-all2all-backends) for a list of all the available modular prepare and finalize subclasses.
|
See [Fused MoE Kernel features](./moe_kernel_features.md#fused-moe-modular-all2all-backends) for a list of all the available modular prepare and finalize subclasses.
|
||||||
|
|
||||||
## FusedMoEPermuteExpertsUnpermute
|
## FusedMoEExpertsModular
|
||||||
|
|
||||||
See [Fused MoE Kernel features](./moe_kernel_features.md#fused-moe-experts-kernels) for a list of all the available modular experts.
|
See [Fused MoE Kernel features](./moe_kernel_features.md#fused-moe-experts-kernels) for a list of all the available modular experts.
|
||||||
|
|||||||
@@ -13,12 +13,13 @@ IOProcessorInput = TypeVar("IOProcessorInput")
|
|||||||
IOProcessorOutput = TypeVar("IOProcessorOutput")
|
IOProcessorOutput = TypeVar("IOProcessorOutput")
|
||||||
|
|
||||||
class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):
|
class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):
|
||||||
def __init__(self, vllm_config: VllmConfig):
|
"""Abstract interface for pre/post-processing of engine I/O."""
|
||||||
|
|
||||||
|
def __init__(self, vllm_config: VllmConfig, renderer: BaseRenderer):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.vllm_config = vllm_config
|
self.vllm_config = vllm_config
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def parse_data(self, data: object) -> IOProcessorInput:
|
def parse_data(self, data: object) -> IOProcessorInput:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@@ -32,7 +33,7 @@ class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):
|
|||||||
self,
|
self,
|
||||||
params: PoolingParams | None = None,
|
params: PoolingParams | None = None,
|
||||||
) -> PoolingParams:
|
) -> PoolingParams:
|
||||||
return params or PoolingParams()
|
return params or PoolingParams(task="plugin")
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def pre_process(
|
def pre_process(
|
||||||
|
|||||||
@@ -656,7 +656,7 @@ vLLM has support for OpenTelemetry tracing:
|
|||||||
- Added by <https://github.com/vllm-project/vllm/pull/4687> and reinstated by <https://github.com/vllm-project/vllm/pull/20372>
|
- Added by <https://github.com/vllm-project/vllm/pull/4687> and reinstated by <https://github.com/vllm-project/vllm/pull/20372>
|
||||||
- Configured with `--otlp-traces-endpoint` and `--collect-detailed-traces`
|
- Configured with `--otlp-traces-endpoint` and `--collect-detailed-traces`
|
||||||
- [OpenTelemetry blog post](https://opentelemetry.io/blog/2024/llm-observability/)
|
- [OpenTelemetry blog post](https://opentelemetry.io/blog/2024/llm-observability/)
|
||||||
- [User-facing docs](../examples/online_serving/opentelemetry.md)
|
- [User-facing docs](../../examples/online_serving/opentelemetry/README.md)
|
||||||
- [Blog post](https://medium.com/@ronen.schaffer/follow-the-trail-supercharging-vllm-with-opentelemetry-distributed-tracing-aa655229b46f)
|
- [Blog post](https://medium.com/@ronen.schaffer/follow-the-trail-supercharging-vllm-with-opentelemetry-distributed-tracing-aa655229b46f)
|
||||||
- [IBM product docs](https://www.ibm.com/docs/en/instana-observability/current?topic=mgaa-monitoring-large-language-models-llms-vllm-public-preview)
|
- [IBM product docs](https://www.ibm.com/docs/en/instana-observability/current?topic=mgaa-monitoring-large-language-models-llms-vllm-public-preview)
|
||||||
|
|
||||||
|
|||||||
198
docs/design/model_runner_v2.md
Normal file
198
docs/design/model_runner_v2.md
Normal file
@@ -0,0 +1,198 @@
|
|||||||
|
# Model Runner V2 Design Document
|
||||||
|
|
||||||
|
## Introduction
|
||||||
|
|
||||||
|
Since vLLM V1 was first implemented, we discovered several fundamental design mistakes and accumulated significant technical debt. Many features were bolted on that were not considered in the original design. We also gained valuable insights into sampling techniques (for example, Gumbel-max sampling), tools (for example, Triton), and CUDA features (for example, UVA). With this knowledge, we implemented Model Runner V2 (MRV2) from first principles to be cleaner, more efficient, and more modular.
|
||||||
|
|
||||||
|
In hindsight, many of V1's design choices were suboptimal. While MRV2 is not yet feature-complete, not rigorously tested, and still has open design decisions, we believe it is a substantial improvement over V1.
|
||||||
|
|
||||||
|
This document describes the design of MRV2.
|
||||||
|
|
||||||
|
## 1. Persistent Batch
|
||||||
|
|
||||||
|
One significant source of friction in V1 is its persistent batch implementation.
|
||||||
|
|
||||||
|
### Background
|
||||||
|
|
||||||
|
V1 introduced persistent batches to minimize CPU overhead during input preparation. When requests are scheduled for a step, the model runner must construct contiguous input tensors (for example, block tables and per-request temperature values) to feed into the model. Building these tensors from scratch each step is often very slow in Python, especially for large tensors like block tables.
|
||||||
|
|
||||||
|
The persistent batch optimization exploits the fact that request batches in consecutive steps are mostly identical. Only a few requests (if any) join or finish per step. By maintaining persistent state tensors and applying incremental diffs instead of reconstructing inputs from scratch, CPU overhead can be reduced significantly.
|
||||||
|
|
||||||
|
### Problems with V1's Approach
|
||||||
|
|
||||||
|
While efficient, V1's persistent batch design introduced unnecessary complexity due to coupling persistent state with input tensors. V1 uses persistent state tensors directly as model and sampler inputs, which imposes strict layout and ordering requirements. When requests join or finish, this often requires complex tensor-wide reordering rather than simple row insertion/removal.
|
||||||
|
|
||||||
|
V1 also had to maintain `CachedRequestState`, a redundant backup copy of request state, because rows in persistent tensors can be overwritten while requests are still active.
|
||||||
|
|
||||||
|
The result is complex bookkeeping that becomes more difficult under async scheduling.
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
### MRV2's Solution
|
||||||
|
|
||||||
|
MRV2 decouples persistent state tensors from per-step input tensors. Given request ordering for the step (usually determined by the attention backend), MRV2 gathers input tensors from persistent state.
|
||||||
|
|
||||||
|
1. Pre-allocate a fixed-size tensor with `max_num_reqs` rows (1024 by default on most platforms).
|
||||||
|
2. Assign each request a permanent row for its active lifetime (until finish or preemption).
|
||||||
|
3. Treat preemption as completion. On resume, re-add request data as fresh state.
|
||||||
|
|
||||||
|
This removes the need for `CachedRequestState` and simplifies bookkeeping. Large state tensors are mostly stored on GPU memory, so gather runs in parallel on the GPU with low overhead.
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
## 2. Async-First
|
||||||
|
|
||||||
|
vLLM now relies heavily on asynchronous scheduling. The scheduler and worker prepare inputs for step `N+1` while the GPU executes step `N`, overlapping CPU and GPU work to maximize utilization.
|
||||||
|
|
||||||
|
V1 was not originally designed with async scheduling in mind, and support required retrofitted behavior and hacks. MRV2 instead assumes the core model execution loop is a CUDA stream with no CPU synchronization points. CPU entrypoints queue work onto the stream.
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
## 3. Removing Async Barrier
|
||||||
|
|
||||||
|
A key requirement for async execution is that CPU operations remain non-blocking. Both explicit sync (for example, `torch.cuda.synchronize`) and implicit sync (for example, unpinned `.to("cuda")`) must be avoided.
|
||||||
|
|
||||||
|
However, async execution can introduce race conditions when CPU and GPU concurrently touch the same memory.
|
||||||
|
|
||||||
|
Example (unsafe):
|
||||||
|
|
||||||
|
```python
|
||||||
|
class ModelRunner:
|
||||||
|
def __init__(self, ...):
|
||||||
|
# Pinned buffer
|
||||||
|
self.states = torch.zeros(
|
||||||
|
max_num_reqs, dtype=torch.int32, device="cpu", pin_memory=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def execute_step(self, ...):
|
||||||
|
self.states[req_idx] = new_req.data
|
||||||
|
states = self.states.to("cuda", non_blocking=True)
|
||||||
|
```
|
||||||
|
|
||||||
|
The CPU may modify `self.states` while GPU is still reading from it via async copy.
|
||||||
|
|
||||||
|
V1 addresses this with an async barrier around critical sections. That avoids races but has drawbacks:
|
||||||
|
|
||||||
|
1. Easy to miss protected buffers (bug-prone).
|
||||||
|
2. Inflexible organization (all CPU work must stay inside barrier).
|
||||||
|
3. Potentially less overlap due to synchronization.
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
### MRV2's Solution: Eliminate the Race
|
||||||
|
|
||||||
|
MRV2 separates persistent CPU state from the copied tensor:
|
||||||
|
|
||||||
|
```python
|
||||||
|
class ModelRunner:
|
||||||
|
def __init__(self, ...):
|
||||||
|
# Not pinned
|
||||||
|
self.states = torch.zeros(
|
||||||
|
max_num_reqs, dtype=torch.int32, device="cpu", pin_memory=False
|
||||||
|
)
|
||||||
|
|
||||||
|
def execute_step(self, ...):
|
||||||
|
self.states[req_idx] = new_req.data
|
||||||
|
tmp_states = self.states.pin_memory()
|
||||||
|
states = tmp_states.to("cuda", non_blocking=True)
|
||||||
|
```
|
||||||
|
|
||||||
|
Now CPU writes to `self.states` while GPU reads from `tmp_states`, eliminating the race without explicit synchronization.
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
## 4. StagedWriteTensor
|
||||||
|
|
||||||
|
For large tensors like block tables, MRV2 avoids full CPU-to-GPU copies each step by using `StagedWriteTensor`:
|
||||||
|
|
||||||
|
1. Keep the base tensor on GPU.
|
||||||
|
2. Stage diffs on CPU.
|
||||||
|
3. Pack diffs into contiguous buffers.
|
||||||
|
4. Copy packed diffs to GPU.
|
||||||
|
5. Launch one kernel to apply diffs.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Initialize state on GPU
|
||||||
|
state = StagedWriteTensor(size=(1024, 1000), dtype=torch.int32, device="cuda")
|
||||||
|
|
||||||
|
# Write [3, 1, 2] into row 2, starting at index 3
|
||||||
|
state.stage_write(row=2, start=3, value=[3, 1, 2])
|
||||||
|
|
||||||
|
# Write [-1, -2, -5] into row 0, starting at index 1
|
||||||
|
state.stage_write(row=0, start=1, value=[-1, -2, -5])
|
||||||
|
|
||||||
|
# Apply staged changes
|
||||||
|
state.apply_write()
|
||||||
|
```
|
||||||
|
|
||||||
|
This supports ragged updates with no CPU-GPU synchronization and minimal kernel launches. It is especially useful for block tables and mixed CPU/GPU-written states such as `num_computed_tokens`.
|
||||||
|
|
||||||
|
## 5. GPU-Native Input Metadata Preparation and Output Processing
|
||||||
|
|
||||||
|
MRV2 uses Triton kernels to prepare inputs such as `input_ids`, `positions`, `query_start_loc`, and `seq_lens`.
|
||||||
|
|
||||||
|
Benefits:
|
||||||
|
|
||||||
|
1. Better async behavior: GPU can derive values (for example with speculative decoding) that CPU may not know yet.
|
||||||
|
2. Lower CPU overhead: input prep is very cheap on GPU and avoids Python bottlenecks.
|
||||||
|
|
||||||
|
	### Unified Virtual Addressing (UVA)
|
||||||
|
|
||||||
|
MRV2 uses UVA in some paths to let GPU kernels access large CPU-resident tensors directly (for example `prefill_token_ids`) without duplicating those tensors into GPU memory.
|
||||||
|
|
||||||
|
## 6. Triton-Native Sampler
|
||||||
|
|
||||||
|
MRV2 reimplements sampling mostly in Triton for better numeric/memory control and optimization.
|
||||||
|
|
||||||
|
### Gumbel Sampling Kernel
|
||||||
|
|
||||||
|
MRV2 introduces a Triton Gumbel sampling kernel that avoids explicit softmax materialization and uses stateless in-kernel RNG from seed input.
|
||||||
|
|
||||||
|
### Efficient Top-K Logprobs
|
||||||
|
|
||||||
|
V1 materializes full-vocabulary logprobs before top-k. MRV2 identifies top-k tokens from logits first, then computes logprobs only for selected tokens. This reduces peak GPU memory usage.
|
||||||
|
|
||||||
|
### Memory-Efficient Prompt Logprobs
|
||||||
|
|
||||||
|
MRV2 supports finer-grained chunking, including chunking inside a single prompt, to avoid memory spikes on long prompts.
|
||||||
|
|
||||||
|
### Better Compatibility with Speculative Decoding
|
||||||
|
|
||||||
|
Instead of expanding per-request sampling states to match per-logit shapes, MRV2 uses indirection (`idx_mapping`) inside kernels to map each logits vector to the right request state. This simplifies support for complex sampling parameters and logits processors.
|
||||||
|
|
||||||
|
## 7. Modularity
|
||||||
|
|
||||||
|
MRV2 emphasizes modularity. Compared to V1's large, entangled `gpu_model_runner.py`, MRV2 splits feature logic across dedicated files (for example, `mrope_utils.py`, `penalties.py`, and many others).
|
||||||
|
|
||||||
|
It also consolidates model inputs into an `InputBatch` class and reduces direct model-runner attribute coupling.
|
||||||
|
|
||||||
|
## 8. No Abuse of `dummy_run`
|
||||||
|
|
||||||
|
In V1, `dummy_run` handled too many responsibilities:
|
||||||
|
|
||||||
|
- Initial memory profiling and `torch.compile`
|
||||||
|
- CUDA graph capture
|
||||||
|
- Warmups
|
||||||
|
- Empty DP forward passes for EP+DP
|
||||||
|
|
||||||
|
MRV2 simplifies this:
|
||||||
|
|
||||||
|
1. `execute_model` supports dummy runs without affecting state.
|
||||||
|
2. `dummy_run` delegates to `execute_model` for profiling, warmup, and empty DP forward passes.
|
||||||
|
3. CUDA graph capture uses a separate dedicated path.
|
||||||
|
|
||||||
|
This reduces complexity and removes bugs caused by divergence between `execute_model` and `dummy_run` behavior.
|
||||||
|
|
||||||
|
## 9. Explicit CUDA Graph Management
|
||||||
|
|
||||||
|
V1's CUDA graph handling is implicit and hard to reason about. MRV2 uses a `CUDAGraphManager` that explicitly captures and launches full CUDA graphs through standard PyTorch APIs.
|
||||||
|
|
||||||
|
This makes graph lifecycle and execution mode decisions more understandable and easier to extend. Example: MRV2 can capture multiple draft-model forward passes into one CUDA graph.
|
||||||
|
|
||||||
|
## Development Philosophy
|
||||||
|
|
||||||
|
MRV2 changes should meet a higher code quality bar. As feature gaps with V1 are filled, features should be reconsidered from first principles in the MRV2 design context instead of quickly porting V1 behavior.
|
||||||
|
|
||||||
|
A key requirement is preserving modularity and clean abstraction boundaries, even if that requires more upfront design iteration.
|
||||||
@@ -4,17 +4,17 @@ The purpose of this document is to provide an overview of the various MoE kernel
|
|||||||
|
|
||||||
## Fused MoE Modular All2All backends
|
## Fused MoE Modular All2All backends
|
||||||
|
|
||||||
There are a number of all2all communication backends that are used to implement expert parallelism (EP) for the `FusedMoE` layer. The different `FusedMoEPrepareAndFinalize` subclasses provide an interface for each all2all backend.
|
There are a number of all2all communication backends that are used to implement expert parallelism (EP) for the `FusedMoE` layer. The different `FusedMoEPrepareAndFinalizeModular` subclasses provide an interface for each all2all backend.
|
||||||
|
|
||||||
The following table describes the relevant features of each backend, i.e. activation format, supported quantization schemes and async support.
|
The following table describes the relevant features of each backend, i.e. activation format, supported quantization schemes and async support.
|
||||||
|
|
||||||
The output activation format (standard or batched) corresponds to the output of the prepare step of the `FusedMoEPrepareAndFinalize` subclass, and the finalize step requires the same format. All the backend `prepare` methods expect activations in the standard format and all the `finalize` methods return activations in standard format. More details on the formats can be found in the [Fused MoE Modular Kernel](./fused_moe_modular_kernel.md) document.
|
The output activation format (standard or batched) corresponds to the output of the prepare step of the `FusedMoEPrepareAndFinalizeModular` subclass, and the finalize step requires the same format. All the backend `prepare` methods expect activations in the standard format and all the `finalize` methods return activations in standard format. More details on the formats can be found in the [Fused MoE Modular Kernel](./fused_moe_modular_kernel.md) document.
|
||||||
|
|
||||||
The quantization types and formats enumerate which quantization schemes are supported by each `FusedMoEPrepareAndFinalize` class. The quantization can happen before or after the dispatch based on the format the all2all backend supports, e.g. deepep_high_throughput supports only block-quantized fp8 format. Any other format will result in dispatching in higher precision and quantizing afterwards. The output of the prepare step for each backend is the quantized type. The finalize step generally requires the same input type as the original activations, e.g. if the original input is bfloat16 and the quantization scheme is fp8 with per-tensor scales, `prepare` will return fp8/per-tensor scale activations and `finalize` will take bfloat16 activations. See the diagrams in [Fused MoE Modular Kernel](./fused_moe_modular_kernel.md) for more details on the types and formats of activations at each step of the MoE process. If no quantization type is specified, the kernel operates on float16 and/or bfloat16.
|
The quantization types and formats enumerate which quantization schemes are supported by each `FusedMoEPrepareAndFinalizeModular` class. The quantization can happen before or after the dispatch based on the format the all2all backend supports, e.g. deepep_high_throughput supports only block-quantized fp8 format. Any other format will result in dispatching in higher precision and quantizing afterwards. The output of the prepare step for each backend is the quantized type. The finalize step generally requires the same input type as the original activations, e.g. if the original input is bfloat16 and the quantization scheme is fp8 with per-tensor scales, `prepare` will return fp8/per-tensor scale activations and `finalize` will take bfloat16 activations. See the diagrams in [Fused MoE Modular Kernel](./fused_moe_modular_kernel.md) for more details on the types and formats of activations at each step of the MoE process. If no quantization type is specified, the kernel operates on float16 and/or bfloat16.
|
||||||
|
|
||||||
Async backends support the use of DBO (Dual Batch Overlap) and shared expert overlap (where shared experts are computed during the combine step).
|
Async backends support the use of DBO (Dual Batch Overlap) and shared expert overlap (where shared experts are computed during the combine step).
|
||||||
|
|
||||||
Certain models require the topk weights to be applied to the input activations rather than the output activations when topk==1, e.g. Llama. For modular kernels, this feature is supported by the `FusedMoEPrepareAndFinalize` subclass. For non-modular kernels, it is up to the experts function to deal with this flag.
|
Certain models require the topk weights to be applied to the input activations rather than the output activations when topk==1, e.g. Llama. For modular kernels, this feature is supported by the `FusedMoEPrepareAndFinalizeModular` subclass. For non-modular kernels, it is up to the experts function to deal with this flag.
|
||||||
|
|
||||||
Unless otherwise specified, backends are controlled via the `--all2all-backend` command-line argument (or the `all2all_backend` parameter in `ParallelConfig`). All backends except `flashinfer` only work with EP+DP or EP+TP. `Flashinfer` can work with EP or DP without EP.
|
Unless otherwise specified, backends are controlled via the `--all2all-backend` command-line argument (or the `all2all_backend` parameter in `ParallelConfig`). All backends except `flashinfer` only work with EP+DP or EP+TP. `Flashinfer` can work with EP or DP without EP.
|
||||||
|
|
||||||
@@ -33,12 +33,9 @@ th {
|
|||||||
| Backend | Output act. format | Quant. types | Quant. format | Async | Apply Weight On Input | Subclass |
|
| Backend | Output act. format | Quant. types | Quant. format | Async | Apply Weight On Input | Subclass |
|
||||||
|---------|--------------------|--------------|---------------|-------|-----------------------|-----------|
|
|---------|--------------------|--------------|---------------|-------|-----------------------|-----------|
|
||||||
| naive | standard | all<sup>1</sup> | G,A,T | N | <sup>6</sup> | [layer.py][vllm.model_executor.layers.fused_moe.layer.FusedMoE] |
|
| naive | standard | all<sup>1</sup> | G,A,T | N | <sup>6</sup> | [layer.py][vllm.model_executor.layers.fused_moe.layer.FusedMoE] |
|
||||||
| pplx | batched | fp8,int8 | G,A,T | Y | Y | [`PplxPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.pplx_prepare_finalize.PplxPrepareAndFinalize] |
|
|
||||||
| deepep_high_throughput | standard | fp8 | G(128),A,T<sup>2</sup> | Y | Y | [`DeepEPHTPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize.DeepEPHTPrepareAndFinalize] |
|
| deepep_high_throughput | standard | fp8 | G(128),A,T<sup>2</sup> | Y | Y | [`DeepEPHTPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize.DeepEPHTPrepareAndFinalize] |
|
||||||
| deepep_low_latency | batched | fp8 | G(128),A,T<sup>3</sup> | Y | Y | [`DeepEPLLPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize.DeepEPLLPrepareAndFinalize] |
|
| deepep_low_latency | batched | fp8 | G(128),A,T<sup>3</sup> | Y | Y | [`DeepEPLLPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize.DeepEPLLPrepareAndFinalize] |
|
||||||
| flashinfer_all2allv | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferA2APrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize.FlashInferA2APrepareAndFinalize] |
|
| flashinfer_all2allv | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferA2APrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize.FlashInferA2APrepareAndFinalize] |
|
||||||
| MoEPrepareAndFinalizeNoEP<sup>5</sup> | standard | fp8,int8 | G,A,T | N | Y | [`MoEPrepareAndFinalizeNoEP`][vllm.model_executor.layers.fused_moe.prepare_finalize.MoEPrepareAndFinalizeNoEP] |
|
|
||||||
| BatchedPrepareAndFinalize<sup>5</sup> | batched | fp8,int8 | G,A,T | N | Y | [`BatchedPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.fused_batched_moe.BatchedPrepareAndFinalize] |
|
|
||||||
|
|
||||||
!!! info "Table key"
|
!!! info "Table key"
|
||||||
1. All types: mxfp4, nvfp4, int4, int8, fp8
|
1. All types: mxfp4, nvfp4, int4, int8, fp8
|
||||||
@@ -68,7 +65,7 @@ Modular kernels are supported by the following `FusedMoEMethodBase` classes.
|
|||||||
|
|
||||||
There are a number of MoE experts kernel implementations for different quantization types and architectures. Most follow the general API of the base Triton [`fused_experts`][vllm.model_executor.layers.fused_moe.fused_moe.fused_experts] function. Many have modular kernel adapters, so they can be used with compatible all2all backends. This table lists each experts kernel and its particular properties.
|
There are a number of MoE experts kernel implementations for different quantization types and architectures. Most follow the general API of the base Triton [`fused_experts`][vllm.model_executor.layers.fused_moe.fused_moe.fused_experts] function. Many have modular kernel adapters, so they can be used with compatible all2all backends. This table lists each experts kernel and its particular properties.
|
||||||
|
|
||||||
Each kernel must be provided with one of the supported input activation formats. Some flavors of kernels support both standard and batched formats through different entry points, e.g. `TritonExperts` and `BatchedTritonExperts`. Batched format kernels are currently only needed for matching with certain all2all backends, e.g. `pplx` and `DeepEPLLPrepareAndFinalize`.
|
Each kernel must be provided with one of the supported input activation formats. Some flavors of kernels support both standard and batched formats through different entry points, e.g. `TritonExperts` and `BatchedTritonExperts`. Batched format kernels are currently only needed for matching with certain all2all backends, e.g. `DeepEPLLPrepareAndFinalize`.
|
||||||
|
|
||||||
Similar to the backend kernels, each experts kernel only supports certain quantization formats. For non-modular experts, the activations will be in the original type and quantized internally by the kernel. Modular experts will expect the activations to already be in the quantized format. Both types of experts will yield outputs in the original activation type.
|
Similar to the backend kernels, each experts kernel only supports certain quantization formats. For non-modular experts, the activations will be in the original type and quantized internally by the kernel. Modular experts will expect the activations to already be in the quantized format. Both types of experts will yield outputs in the original activation type.
|
||||||
|
|
||||||
@@ -76,9 +73,9 @@ Each experts kernel supports one or more activation functions, e.g. silu or gelu
|
|||||||
|
|
||||||
As with the backends, some experts support applying topk weights on the input activations. The entries in the column in this table only apply to the non-modular experts.
|
As with the backends, some experts support applying topk weights on the input activations. The entries in the column in this table only apply to the non-modular experts.
|
||||||
|
|
||||||
Most experts flavors include an equivalent modular interface which will be a subclass of `FusedMoEPermuteExpertsUnpermute`.
|
Most experts flavors include an equivalent modular interface which will be a subclass of `FusedMoEExpertsModular`.
|
||||||
|
|
||||||
To be used with a particular `FusedMoEPrepareAndFinalize` subclass, MoE kernels must have compatible activation formats, quantization types and quantization formats.
|
To be used with a particular `FusedMoEPrepareAndFinalizeModular` subclass, MoE kernels must have compatible activation formats, quantization types and quantization formats.
|
||||||
|
|
||||||
| Kernel | Input act. format | Quant. types | Quant. format | Activation function | Apply Weight On Input | Modular | Source |
|
| Kernel | Input act. format | Quant. types | Quant. format | Activation function | Apply Weight On Input | Modular | Source |
|
||||||
|--------|-------------------|--------------|---------------|---------------------|-----------------------|---------|--------|
|
|--------|-------------------|--------------|---------------|---------------------|-----------------------|---------|--------|
|
||||||
@@ -107,8 +104,8 @@ To be used with a particular `FusedMoEPrepareAndFinalize` subclass, MoE kernels
|
|||||||
|
|
||||||
The following table shows "families" of modular kernels that are intended to work together. There are some combinations which may work but have not yet been tested, e.g. flashinfer with other fp8 experts. Note that the "naive" backend will work with any non-modular experts.
|
The following table shows "families" of modular kernels that are intended to work together. There are some combinations which may work but have not yet been tested, e.g. flashinfer with other fp8 experts. Note that the "naive" backend will work with any non-modular experts.
|
||||||
|
|
||||||
| backend | `FusedMoEPrepareAndFinalize` subclasses | `FusedMoEPermuteExpertsUnpermute` subclasses |
|
| backend | `FusedMoEPrepareAndFinalizeModular` subclasses | `FusedMoEExpertsModular` subclasses |
|
||||||
|---------|-----------------------------------------|----------------------------------------------|
|
|---------|-----------------------------------------|----------------------------------------------|
|
||||||
| deepep_high_throughput | `DeepEPHTPrepareAndFinalize` | `DeepGemmExperts`,</br>`TritonExperts`,</br>`TritonOrDeepGemmExperts`,</br>`CutlassExpertsFp8`, </br>`MarlinExperts` |
|
| deepep_high_throughput | `DeepEPHTPrepareAndFinalize` | `DeepGemmExperts`,</br>`TritonExperts`,</br>`TritonOrDeepGemmExperts`,</br>`CutlassExpertsFp8`, </br>`MarlinExperts` |
|
||||||
| deepep_low_latency,</br>pplx | `DeepEPLLPrepareAndFinalize`,</br>`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,</br>`BatchedTritonExperts`,</br>`CutlassBatchedExpertsFp8`,</br>`BatchedMarlinExperts` |
|
| deepep_low_latency | `DeepEPLLPrepareAndFinalize` | `BatchedDeepGemmExperts`,</br>`BatchedTritonExperts`,</br>`CutlassBatchedExpertsFp8`,</br>`BatchedMarlinExperts` |
|
||||||
| flashinfer | `FlashInferCutlassMoEPrepareAndFinalize` | `FlashInferExperts` |
|
| flashinfer | `FlashInferCutlassMoEPrepareAndFinalize` | `FlashInferExperts` |
|
||||||
|
|||||||
@@ -55,6 +55,7 @@ Sorted alphabetically by GitHub handle:
|
|||||||
- [@ywang96](https://github.com/ywang96): Multimodality, benchmarks
|
- [@ywang96](https://github.com/ywang96): Multimodality, benchmarks
|
||||||
- [@zhuohan123](https://github.com/zhuohan123): Project lead, RL integration, numerics
|
- [@zhuohan123](https://github.com/zhuohan123): Project lead, RL integration, numerics
|
||||||
- [@zou3519](https://github.com/zou3519): Compilation
|
- [@zou3519](https://github.com/zou3519): Compilation
|
||||||
|
- [@BoyuanFeng](https://github.com/BoyuanFeng): Compilation, CUDAGraph
|
||||||
|
|
||||||
### Emeritus Committers
|
### Emeritus Committers
|
||||||
|
|
||||||
@@ -113,7 +114,7 @@ If you have PRs touching the area, please feel free to ping the area owner for r
|
|||||||
- Multi-modal Input Processing: Components that load and process image/video/audio data into feature tensors
|
- Multi-modal Input Processing: Components that load and process image/video/audio data into feature tensors
|
||||||
- @DarkLight1337, @ywang96, @Isotr0py
|
- @DarkLight1337, @ywang96, @Isotr0py
|
||||||
- torch compile: The torch.compile integration in vLLM, custom passes & transformations
|
- torch compile: The torch.compile integration in vLLM, custom passes & transformations
|
||||||
- @ProExpertProg, @zou3519, @youkaichao
|
- @ProExpertProg, @zou3519, @youkaichao, @BoyuanFeng
|
||||||
- State space models: The state space models implementation in vLLM
|
- State space models: The state space models implementation in vLLM
|
||||||
- @tdoublep, @tlrmchlsmth
|
- @tdoublep, @tlrmchlsmth
|
||||||
- Reasoning and tool calling parsers
|
- Reasoning and tool calling parsers
|
||||||
@@ -154,7 +155,7 @@ If you have PRs touching the area, please feel free to ping the area owner for r
|
|||||||
- FlashAttention: @LucasWilkinson
|
- FlashAttention: @LucasWilkinson
|
||||||
- FlashInfer: @LucasWilkinson, @mgoin, @WoosukKwon
|
- FlashInfer: @LucasWilkinson, @mgoin, @WoosukKwon
|
||||||
- Blackwell Kernels: @mgoin, @yewentao256
|
- Blackwell Kernels: @mgoin, @yewentao256
|
||||||
- DeepEP/DeepGEMM/pplx: @mgoin, @yewentao256
|
- DeepEP/DeepGEMM: @mgoin, @yewentao256
|
||||||
|
|
||||||
### Integrations
|
### Integrations
|
||||||
|
|
||||||
|
|||||||
@@ -100,8 +100,8 @@ bench_sweep_plot_pareto = auto_mock(
|
|||||||
"vllm.benchmarks.sweep.plot_pareto", "SweepPlotParetoArgs"
|
"vllm.benchmarks.sweep.plot_pareto", "SweepPlotParetoArgs"
|
||||||
)
|
)
|
||||||
bench_sweep_serve = auto_mock("vllm.benchmarks.sweep.serve", "SweepServeArgs")
|
bench_sweep_serve = auto_mock("vllm.benchmarks.sweep.serve", "SweepServeArgs")
|
||||||
bench_sweep_serve_sla = auto_mock(
|
bench_sweep_serve_workload = auto_mock(
|
||||||
"vllm.benchmarks.sweep.serve_sla", "SweepServeSLAArgs"
|
"vllm.benchmarks.sweep.serve_workload", "SweepServeWorkloadArgs"
|
||||||
)
|
)
|
||||||
bench_throughput = auto_mock("vllm.benchmarks", "throughput")
|
bench_throughput = auto_mock("vllm.benchmarks", "throughput")
|
||||||
AsyncEngineArgs = auto_mock("vllm.engine.arg_utils", "AsyncEngineArgs")
|
AsyncEngineArgs = auto_mock("vllm.engine.arg_utils", "AsyncEngineArgs")
|
||||||
@@ -229,7 +229,9 @@ def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool):
|
|||||||
"bench_sweep_plot": create_parser(bench_sweep_plot.add_cli_args),
|
"bench_sweep_plot": create_parser(bench_sweep_plot.add_cli_args),
|
||||||
"bench_sweep_plot_pareto": create_parser(bench_sweep_plot_pareto.add_cli_args),
|
"bench_sweep_plot_pareto": create_parser(bench_sweep_plot_pareto.add_cli_args),
|
||||||
"bench_sweep_serve": create_parser(bench_sweep_serve.add_cli_args),
|
"bench_sweep_serve": create_parser(bench_sweep_serve.add_cli_args),
|
||||||
"bench_sweep_serve_sla": create_parser(bench_sweep_serve_sla.add_cli_args),
|
"bench_sweep_serve_workload": create_parser(
|
||||||
|
bench_sweep_serve_workload.add_cli_args
|
||||||
|
),
|
||||||
"bench_throughput": create_parser(bench_throughput.add_cli_args),
|
"bench_throughput": create_parser(bench_throughput.add_cli_args),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -498,6 +498,133 @@ curl -s http://localhost:8000/pooling -H "Content-Type: application/json" -d '{
|
|||||||
- Multi-vector retrieval: [examples/pooling/token_embed/colqwen3_token_embed_online.py](../../examples/pooling/token_embed/colqwen3_token_embed_online.py)
|
- Multi-vector retrieval: [examples/pooling/token_embed/colqwen3_token_embed_online.py](../../examples/pooling/token_embed/colqwen3_token_embed_online.py)
|
||||||
- Reranking (text + multi-modal): [examples/pooling/score/colqwen3_rerank_online.py](../../examples/pooling/score/colqwen3_rerank_online.py)
|
- Reranking (text + multi-modal): [examples/pooling/score/colqwen3_rerank_online.py](../../examples/pooling/score/colqwen3_rerank_online.py)
|
||||||
|
|
||||||
|
### Llama Nemotron Multimodal
|
||||||
|
|
||||||
|
#### Embedding Model
|
||||||
|
|
||||||
|
Llama Nemotron VL Embedding models combine the bidirectional Llama embedding backbone
|
||||||
|
(from `nvidia/llama-nemotron-embed-1b-v2`) with SigLIP as the vision encoder to produce
|
||||||
|
single-vector embeddings from text and/or images.
|
||||||
|
|
||||||
|
| Architecture | Backbone | Example HF Models |
|
||||||
|
|---|---|---|
|
||||||
|
| `LlamaNemotronVLModel` | Bidirectional Llama + SigLIP | `nvidia/llama-nemotron-embed-vl-1b-v2` |
|
||||||
|
|
||||||
|
Start the server:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
vllm serve nvidia/llama-nemotron-embed-vl-1b-v2 \
|
||||||
|
--trust-remote-code \
|
||||||
|
--chat-template examples/pooling/embed/template/nemotron_embed_vl.jinja
|
||||||
|
```
|
||||||
|
|
||||||
|
!!! note
|
||||||
|
The chat template bundled with this model's tokenizer is not suitable for
|
||||||
|
the embeddings API. Use the provided override template above when serving
|
||||||
|
with the `messages`-based (chat-style) embeddings endpoint.
|
||||||
|
|
||||||
|
The override template uses the message `role` to automatically prepend the
|
||||||
|
appropriate prefix: set `role` to `"query"` for queries (prepends `query: `)
|
||||||
|
or `"document"` for passages (prepends `passage: `). Any other role omits
|
||||||
|
the prefix.
|
||||||
|
|
||||||
|
Embed text queries:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
curl -s http://localhost:8000/v1/embeddings -H "Content-Type: application/json" -d '{
|
||||||
|
"model": "nvidia/llama-nemotron-embed-vl-1b-v2",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "query",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "What is machine learning?"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
Embed images via the chat-style `messages` field:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
curl -s http://localhost:8000/v1/embeddings -H "Content-Type: application/json" -d '{
|
||||||
|
"model": "nvidia/llama-nemotron-embed-vl-1b-v2",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "document",
|
||||||
|
"content": [
|
||||||
|
{"type": "image_url", "image_url": {"url": "data:image/png;base64,<BASE64>"}},
|
||||||
|
{"type": "text", "text": "Describe the image."}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Reranker Model
|
||||||
|
|
||||||
|
Llama Nemotron VL reranker models combine the same bidirectional Llama + SigLIP
|
||||||
|
backbone with a sequence-classification head for cross-encoder scoring and reranking.
|
||||||
|
|
||||||
|
| Architecture | Backbone | Example HF Models |
|
||||||
|
|---|---|---|
|
||||||
|
| `LlamaNemotronVLForSequenceClassification` | Bidirectional Llama + SigLIP | `nvidia/llama-nemotron-rerank-vl-1b-v2` |
|
||||||
|
|
||||||
|
Start the server:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
vllm serve nvidia/llama-nemotron-rerank-vl-1b-v2 \
|
||||||
|
--runner pooling \
|
||||||
|
--trust-remote-code \
|
||||||
|
--chat-template examples/pooling/score/template/nemotron-vl-rerank.jinja
|
||||||
|
```
|
||||||
|
|
||||||
|
!!! note
|
||||||
|
The chat template bundled with this checkpoint's tokenizer is not suitable
|
||||||
|
for the Score/Rerank APIs. Use the provided override template when serving:
|
||||||
|
`examples/pooling/score/template/nemotron-vl-rerank.jinja`.
|
||||||
|
|
||||||
|
Score a text query against an image document:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
curl -s http://localhost:8000/score -H "Content-Type: application/json" -d '{
|
||||||
|
"model": "nvidia/llama-nemotron-rerank-vl-1b-v2",
|
||||||
|
"data_1": "Find diagrams about autonomous robots",
|
||||||
|
"data_2": [
|
||||||
|
{
|
||||||
|
"content": [
|
||||||
|
{"type": "image_url", "image_url": {"url": "data:image/png;base64,<BASE64>"}},
|
||||||
|
{"type": "text", "text": "Robotics workflow diagram."}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
Rerank image documents by a text query:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
curl -s http://localhost:8000/rerank -H "Content-Type: application/json" -d '{
|
||||||
|
"model": "nvidia/llama-nemotron-rerank-vl-1b-v2",
|
||||||
|
"query": "Find diagrams about autonomous robots",
|
||||||
|
"documents": [
|
||||||
|
{
|
||||||
|
"content": [
|
||||||
|
{"type": "image_url", "image_url": {"url": "data:image/png;base64,<BASE64_1>"}},
|
||||||
|
{"type": "text", "text": "Robotics workflow diagram."}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"content": [
|
||||||
|
{"type": "image_url", "image_url": {"url": "data:image/png;base64,<BASE64_2>"}},
|
||||||
|
{"type": "text", "text": "General skyline photo."}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_n": 2
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
### BAAI/bge-m3
|
### BAAI/bge-m3
|
||||||
|
|
||||||
The `BAAI/bge-m3` model comes with extra weights for sparse and colbert embeddings but unfortunately in its `config.json`
|
The `BAAI/bge-m3` model comes with extra weights for sparse and colbert embeddings but unfortunately in its `config.json`
|
||||||
|
|||||||
@@ -369,9 +369,11 @@ th {
|
|||||||
| `AquilaForCausalLM` | Aquila, Aquila2 | `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc. | ✅︎ | ✅︎ |
|
| `AquilaForCausalLM` | Aquila, Aquila2 | `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc. | ✅︎ | ✅︎ |
|
||||||
| `ArceeForCausalLM` | Arcee (AFM) | `arcee-ai/AFM-4.5B-Base`, etc. | ✅︎ | ✅︎ |
|
| `ArceeForCausalLM` | Arcee (AFM) | `arcee-ai/AFM-4.5B-Base`, etc. | ✅︎ | ✅︎ |
|
||||||
| `ArcticForCausalLM` | Arctic | `Snowflake/snowflake-arctic-base`, `Snowflake/snowflake-arctic-instruct`, etc. | | ✅︎ |
|
| `ArcticForCausalLM` | Arctic | `Snowflake/snowflake-arctic-base`, `Snowflake/snowflake-arctic-instruct`, etc. | | ✅︎ |
|
||||||
|
| `AXK1ForCausalLM` | A.X-K1 | `skt/A.X-K1`, etc. | | ✅︎ |
|
||||||
| `BaiChuanForCausalLM` | Baichuan2, Baichuan | `baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc. | ✅︎ | ✅︎ |
|
| `BaiChuanForCausalLM` | Baichuan2, Baichuan | `baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc. | ✅︎ | ✅︎ |
|
||||||
| `BailingMoeForCausalLM` | Ling | `inclusionAI/Ling-lite-1.5`, `inclusionAI/Ling-plus`, etc. | ✅︎ | ✅︎ |
|
| `BailingMoeForCausalLM` | Ling | `inclusionAI/Ling-lite-1.5`, `inclusionAI/Ling-plus`, etc. | ✅︎ | ✅︎ |
|
||||||
| `BailingMoeV2ForCausalLM` | Ling | `inclusionAI/Ling-mini-2.0`, etc. | ✅︎ | ✅︎ |
|
| `BailingMoeV2ForCausalLM` | Ling | `inclusionAI/Ling-mini-2.0`, etc. | ✅︎ | ✅︎ |
|
||||||
|
| `BailingMoeV2_5ForCausalLM` | Ling | `inclusionAI/Ling-2.5-1T`, `inclusionAI/Ring-2.5-1T` | | ✅︎ |
|
||||||
| `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | ✅︎ | ✅︎ |
|
| `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | ✅︎ | ✅︎ |
|
||||||
| `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | ✅︎ |
|
| `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | ✅︎ |
|
||||||
| `ChatGLMModel`, `ChatGLMForConditionalGeneration` | ChatGLM | `zai-org/chatglm2-6b`, `zai-org/chatglm3-6b`, `thu-coai/ShieldLM-6B-chatglm3`, etc. | ✅︎ | ✅︎ |
|
| `ChatGLMModel`, `ChatGLMForConditionalGeneration` | ChatGLM | `zai-org/chatglm2-6b`, `zai-org/chatglm3-6b`, `thu-coai/ShieldLM-6B-chatglm3`, etc. | ✅︎ | ✅︎ |
|
||||||
@@ -791,6 +793,7 @@ Speech2Text models trained specifically for Automatic Speech Recognition.
|
|||||||
|
|
||||||
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) |
|
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) |
|
||||||
|--------------|--------|-------------------|----------------------|---------------------------|
|
|--------------|--------|-------------------|----------------------|---------------------------|
|
||||||
|
| `FireRedASR2ForConditionalGeneration` | FireRedASR2 | `allendou/FireRedASR2-LLM-vllm`, etc. | | |
|
||||||
| `FunASRForConditionalGeneration` | FunASR | `allendou/Fun-ASR-Nano-2512-vllm`, etc. | | |
|
| `FunASRForConditionalGeneration` | FunASR | `allendou/Fun-ASR-Nano-2512-vllm`, etc. | | |
|
||||||
| `Gemma3nForConditionalGeneration` | Gemma3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | |
|
| `Gemma3nForConditionalGeneration` | Gemma3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | |
|
||||||
| `GlmAsrForConditionalGeneration` | GLM-ASR | `zai-org/GLM-ASR-Nano-2512` | ✅︎ | ✅︎ |
|
| `GlmAsrForConditionalGeneration` | GLM-ASR | `zai-org/GLM-ASR-Nano-2512` | ✅︎ | ✅︎ |
|
||||||
@@ -820,6 +823,7 @@ The following table lists those that are tested in vLLM.
|
|||||||
|--------------|--------|--------|-------------------|----------------------|---------------------------|
|
|--------------|--------|--------|-------------------|----------------------|---------------------------|
|
||||||
| `CLIPModel` | CLIP | T / I | `openai/clip-vit-base-patch32`, `openai/clip-vit-large-patch14`, etc. | | |
|
| `CLIPModel` | CLIP | T / I | `openai/clip-vit-base-patch32`, `openai/clip-vit-large-patch14`, etc. | | |
|
||||||
| `ColModernVBertForRetrieval` | ColModernVBERT | T / I | `ModernVBERT/colmodernvbert-merged` | | |
|
| `ColModernVBertForRetrieval` | ColModernVBERT | T / I | `ModernVBERT/colmodernvbert-merged` | | |
|
||||||
|
| `LlamaNemotronVLModel` | Llama Nemotron Embedding + SigLIP | T + I | `nvidia/llama-nemotron-embed-vl-1b-v2` | | |
|
||||||
| `LlavaNextForConditionalGeneration`<sup>C</sup> | LLaVA-NeXT-based | T / I | `royokong/e5-v` | | ✅︎ |
|
| `LlavaNextForConditionalGeneration`<sup>C</sup> | LLaVA-NeXT-based | T / I | `royokong/e5-v` | | ✅︎ |
|
||||||
| `Phi3VForCausalLM`<sup>C</sup> | Phi-3-Vision-based | T + I | `TIGER-Lab/VLM2Vec-Full` | | ✅︎ |
|
| `Phi3VForCausalLM`<sup>C</sup> | Phi-3-Vision-based | T + I | `TIGER-Lab/VLM2Vec-Full` | | ✅︎ |
|
||||||
| `Qwen3VLForConditionalGeneration`<sup>C</sup> | Qwen3-VL | T + I + V | `Qwen/Qwen3-VL-Embedding-2B`, etc. | ✅︎ | ✅︎ |
|
| `Qwen3VLForConditionalGeneration`<sup>C</sup> | Qwen3-VL | T + I + V | `Qwen/Qwen3-VL-Embedding-2B`, etc. | ✅︎ | ✅︎ |
|
||||||
@@ -839,6 +843,7 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A
|
|||||||
| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) |
|
| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) |
|
||||||
|--------------|--------|--------|-------------------|----------------------|---------------------------|
|
|--------------|--------|--------|-------------------|----------------------|---------------------------|
|
||||||
| `JinaVLForSequenceClassification` | JinaVL-based | T + I<sup>E+</sup> | `jinaai/jina-reranker-m0`, etc. | ✅︎ | ✅︎ |
|
| `JinaVLForSequenceClassification` | JinaVL-based | T + I<sup>E+</sup> | `jinaai/jina-reranker-m0`, etc. | ✅︎ | ✅︎ |
|
||||||
|
| `LlamaNemotronVLForSequenceClassification` | Llama Nemotron Reranker + SigLIP | T + I<sup>E+</sup> | `nvidia/llama-nemotron-rerank-vl-1b-v2` | | |
|
||||||
| `Qwen3VLForSequenceClassification` | Qwen3-VL-Reranker | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen3-VL-Reranker-2B`(see note), etc. | ✅︎ | ✅︎ |
|
| `Qwen3VLForSequenceClassification` | Qwen3-VL-Reranker | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen3-VL-Reranker-2B`(see note), etc. | ✅︎ | ✅︎ |
|
||||||
|
|
||||||
<sup>C</sup> Automatically converted into a classification model via `--convert classify`. ([details](./pooling_models.md#model-conversion))
|
<sup>C</sup> Automatically converted into a classification model via `--convert classify`. ([details](./pooling_models.md#model-conversion))
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ EP is typically coupled with Data Parallelism (DP). While DP can be used indepen
|
|||||||
|
|
||||||
Before using EP, you need to install the necessary dependencies. We are actively working on making this easier in the future:
|
Before using EP, you need to install the necessary dependencies. We are actively working on making this easier in the future:
|
||||||
|
|
||||||
1. **Install DeepEP and pplx-kernels**: Set up host environment following vLLM's guide for EP kernels [here](../../tools/ep_kernels).
|
1. **Install DeepEP**: Set up host environment following vLLM's guide for EP kernels [here](../../tools/ep_kernels).
|
||||||
2. **Install DeepGEMM library**: Follow the [official instructions](https://github.com/deepseek-ai/DeepGEMM#installation).
|
2. **Install DeepGEMM library**: Follow the [official instructions](https://github.com/deepseek-ai/DeepGEMM#installation).
|
||||||
3. **For disaggregated serving**: Install `gdrcopy` by running the [`install_gdrcopy.sh`](../../tools/install_gdrcopy.sh) script (e.g., `install_gdrcopy.sh "${GDRCOPY_OS_VERSION}" "12.8" "x64"`). You can find available OS versions [here](https://developer.download.nvidia.com/compute/redist/gdrcopy/CUDA%2012.8/).
|
3. **For disaggregated serving**: Install `gdrcopy` by running the [`install_gdrcopy.sh`](../../tools/install_gdrcopy.sh) script (e.g., `install_gdrcopy.sh "${GDRCOPY_OS_VERSION}" "12.8" "x64"`). You can find available OS versions [here](https://developer.download.nvidia.com/compute/redist/gdrcopy/CUDA%2012.8/).
|
||||||
|
|
||||||
@@ -19,7 +19,6 @@ vLLM provides multiple communication backends for EP. Use `--all2all-backend` to
|
|||||||
| Backend | Use Case | Features | Best For |
|
| Backend | Use Case | Features | Best For |
|
||||||
|---------|----------|----------|----------|
|
|---------|----------|----------|----------|
|
||||||
| `allgather_reducescatter` | Default backend | Standard all2all using allgather/reducescatter primitives | General purpose, works with any EP+DP configuration |
|
| `allgather_reducescatter` | Default backend | Standard all2all using allgather/reducescatter primitives | General purpose, works with any EP+DP configuration |
|
||||||
| `pplx` | Single node | Chunked prefill support, efficient intra-node communication | Single-node deployments, development |
|
|
||||||
| `deepep_high_throughput` | Multi-node prefill | Grouped GEMM with continuous layout, optimized for prefill | Prefill-dominated workloads, high-throughput scenarios |
|
| `deepep_high_throughput` | Multi-node prefill | Grouped GEMM with continuous layout, optimized for prefill | Prefill-dominated workloads, high-throughput scenarios |
|
||||||
| `deepep_low_latency` | Multi-node decode | CUDA graph support, masked layout, optimized for decode | Decode-dominated workloads, low-latency scenarios |
|
| `deepep_low_latency` | Multi-node decode | CUDA graph support, masked layout, optimized for decode | Decode-dominated workloads, low-latency scenarios |
|
||||||
| `flashinfer_all2allv` | MNNVL systems | FlashInfer alltoallv kernels for multi-node NVLink | Systems with NVLink across nodes |
|
| `flashinfer_all2allv` | MNNVL systems | FlashInfer alltoallv kernels for multi-node NVLink | Systems with NVLink across nodes |
|
||||||
@@ -71,12 +70,11 @@ For example, with `TP=2, DP=4` (8 GPUs total):
|
|||||||
The following command serves a `DeepSeek-V3-0324` model with 1-way tensor parallel, 8-way (attention) data parallel, and 8-way expert parallel. The attention weights are replicated across all GPUs, while the expert weights are split across GPUs. It will work on a H200 (or H20) node with 8 GPUs. For H100, you can try to serve a smaller model or refer to the multi-node deployment section.
|
The following command serves a `DeepSeek-V3-0324` model with 1-way tensor parallel, 8-way (attention) data parallel, and 8-way expert parallel. The attention weights are replicated across all GPUs, while the expert weights are split across GPUs. It will work on a H200 (or H20) node with 8 GPUs. For H100, you can try to serve a smaller model or refer to the multi-node deployment section.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Single node EP deployment with pplx backend
|
# Single node EP deployment
|
||||||
vllm serve deepseek-ai/DeepSeek-V3-0324 \
|
vllm serve deepseek-ai/DeepSeek-V3-0324 \
|
||||||
--tensor-parallel-size 1 \ # Tensor parallelism across 1 GPU
|
--tensor-parallel-size 1 \ # Tensor parallelism across 1 GPU
|
||||||
--data-parallel-size 8 \ # Data parallelism across 8 processes
|
--data-parallel-size 8 \ # Data parallelism across 8 processes
|
||||||
--enable-expert-parallel \ # Enable expert parallelism
|
--enable-expert-parallel # Enable expert parallelism
|
||||||
--all2all-backend pplx # Use pplx communication backend
|
|
||||||
```
|
```
|
||||||
|
|
||||||
## Multi-Node Deployment
|
## Multi-Node Deployment
|
||||||
@@ -197,7 +195,6 @@ vllm serve deepseek-ai/DeepSeek-V3-0324 \
|
|||||||
--tensor-parallel-size 1 \ # Tensor parallelism
|
--tensor-parallel-size 1 \ # Tensor parallelism
|
||||||
--data-parallel-size 8 \ # Data parallelism
|
--data-parallel-size 8 \ # Data parallelism
|
||||||
--enable-expert-parallel \ # Enable EP
|
--enable-expert-parallel \ # Enable EP
|
||||||
--all2all-backend pplx \ # Use pplx communication backend
|
|
||||||
--enable-eplb \ # Enable load balancer
|
--enable-eplb \ # Enable load balancer
|
||||||
--eplb-config '{"window_size":1000,"step_interval":3000,"num_redundant_experts":2,"log_balancedness":true}'
|
--eplb-config '{"window_size":1000,"step_interval":3000,"num_redundant_experts":2,"log_balancedness":true}'
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -84,7 +84,7 @@ In order for the language model to support chat protocol, vLLM requires the mode
|
|||||||
a chat template in its tokenizer configuration. The chat template is a Jinja2 template that
|
a chat template in its tokenizer configuration. The chat template is a Jinja2 template that
|
||||||
specifies how roles, messages, and other chat-specific tokens are encoded in the input.
|
specifies how roles, messages, and other chat-specific tokens are encoded in the input.
|
||||||
|
|
||||||
An example chat template for `NousResearch/Meta-Llama-3-8B-Instruct` can be found [here](https://github.com/meta-llama/llama3?tab=readme-ov-file#instruction-tuned-models)
|
An example chat template for `NousResearch/Meta-Llama-3-8B-Instruct` can be found [here](https://llama.com/docs/model-cards-and-prompt-formats/meta-llama-3/#prompt-template-for-meta-llama-3)
|
||||||
|
|
||||||
Some models do not provide a chat template even though they are instruction/chat fine-tuned. For those models,
|
Some models do not provide a chat template even though they are instruction/chat fine-tuned. For those models,
|
||||||
you can manually specify their chat template in the `--chat-template` parameter with the file path to the chat
|
you can manually specify their chat template in the `--chat-template` parameter with the file path to the chat
|
||||||
|
|||||||
58
examples/offline_inference/extract_hidden_states.py
Normal file
58
examples/offline_inference/extract_hidden_states.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
from safetensors import safe_open
|
||||||
|
|
||||||
|
from vllm import LLM, SamplingParams
|
||||||
|
|
||||||
|
# Example: Using the custom "extract_hidden_states" speculator method and
|
||||||
|
# ExampleHiddenStatesConnector to extract and save hidden states from vllm
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
llm = LLM(
|
||||||
|
model="Qwen/Qwen3-8B", # Your target model
|
||||||
|
speculative_config={
|
||||||
|
"method": "extract_hidden_states",
|
||||||
|
"num_speculative_tokens": 1,
|
||||||
|
"draft_model_config": {
|
||||||
|
"hf_config": {
|
||||||
|
"eagle_aux_hidden_state_layer_ids": [ # Target model layer indices
|
||||||
|
1,
|
||||||
|
2,
|
||||||
|
3,
|
||||||
|
4,
|
||||||
|
],
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
kv_transfer_config={
|
||||||
|
"kv_connector": "ExampleHiddenStatesConnector",
|
||||||
|
"kv_role": "kv_producer",
|
||||||
|
"kv_connector_extra_config": {
|
||||||
|
"shared_storage_path": tmpdirname,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
prompts = ["Generate a sentence with hidden states", "Write a python function"]
|
||||||
|
sampling_params = SamplingParams(max_tokens=1)
|
||||||
|
outputs = llm.generate(prompts, sampling_params)
|
||||||
|
|
||||||
|
for output in outputs:
|
||||||
|
print("\nPrompt:", output.prompt)
|
||||||
|
print("Prompt token ids:", output.prompt_token_ids)
|
||||||
|
|
||||||
|
hidden_states_path = output.kv_transfer_params.get("hidden_states_path")
|
||||||
|
assert hidden_states_path is not None
|
||||||
|
print("Prompt hidden states path:", hidden_states_path)
|
||||||
|
|
||||||
|
with safe_open(hidden_states_path, "pt") as f:
|
||||||
|
token_ids = f.get_tensor("token_ids")
|
||||||
|
hidden_states = f.get_tensor("hidden_states")
|
||||||
|
|
||||||
|
print("Extracted token ids:", token_ids) # Matches prompt token ids
|
||||||
|
print(
|
||||||
|
"Extracted hidden states shape:", hidden_states.shape
|
||||||
|
) # [num_hidden_layers, prompt len, hidden size]
|
||||||
|
print("Extracted hidden states:", hidden_states)
|
||||||
@@ -42,16 +42,19 @@ from vllm.distributed.weight_transfer.base import (
|
|||||||
WeightTransferUpdateRequest,
|
WeightTransferUpdateRequest,
|
||||||
)
|
)
|
||||||
from vllm.distributed.weight_transfer.nccl_engine import (
|
from vllm.distributed.weight_transfer.nccl_engine import (
|
||||||
|
NCCLTrainerSendWeightsArgs,
|
||||||
NCCLWeightTransferEngine,
|
NCCLWeightTransferEngine,
|
||||||
NCCLWeightTransferInitInfo,
|
NCCLWeightTransferInitInfo,
|
||||||
NCCLWeightTransferUpdateInfo,
|
NCCLWeightTransferUpdateInfo,
|
||||||
)
|
)
|
||||||
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils.network_utils import get_ip, get_open_port
|
from vllm.utils.network_utils import get_ip, get_open_port
|
||||||
from vllm.v1.executor import Executor
|
from vllm.v1.executor import Executor
|
||||||
|
|
||||||
MODEL_NAME_V1 = "Qwen/Qwen3-1.7B-Base"
|
MODEL_NAME_V1 = "Qwen/Qwen3-1.7B-Base"
|
||||||
MODEL_NAME_V2 = "Qwen/Qwen3-1.7B"
|
MODEL_NAME_V2 = "Qwen/Qwen3-1.7B"
|
||||||
PAUSE_TOKEN_THRESHOLD = 10
|
PAUSE_TOKEN_THRESHOLD = 10
|
||||||
|
ATTN_BACKEND = "TRITON_ATTN" if current_platform.is_rocm() else "FLASH_ATTN"
|
||||||
|
|
||||||
|
|
||||||
class MyLLM(vllm.AsyncLLMEngine):
|
class MyLLM(vllm.AsyncLLMEngine):
|
||||||
@@ -103,7 +106,7 @@ class MyLLM(vllm.AsyncLLMEngine):
|
|||||||
while not self._request_pause_flag:
|
while not self._request_pause_flag:
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
await super().pause_generation(mode="keep")
|
await super().pause_generation(mode="keep")
|
||||||
await asyncio.sleep(0.2)
|
await asyncio.sleep(5)
|
||||||
self._generation_paused = True
|
self._generation_paused = True
|
||||||
|
|
||||||
|
|
||||||
@@ -115,10 +118,16 @@ class TrainModel:
|
|||||||
from vllm.model_executor.layers.batch_invariant import (
|
from vllm.model_executor.layers.batch_invariant import (
|
||||||
init_batch_invariance,
|
init_batch_invariance,
|
||||||
)
|
)
|
||||||
|
from vllm.platforms import current_platform
|
||||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||||
|
|
||||||
# need to init all env vars for batch invariance which affect nccl ops
|
# need to init all env vars for batch invariance which affect nccl ops
|
||||||
init_batch_invariance(AttentionBackendEnum.FLASH_ATTN)
|
attn_backend = (
|
||||||
|
AttentionBackendEnum.TRITON_ATTN
|
||||||
|
if current_platform.is_rocm()
|
||||||
|
else AttentionBackendEnum.FLASH_ATTN
|
||||||
|
)
|
||||||
|
init_batch_invariance(attn_backend)
|
||||||
|
|
||||||
self.model = AutoModelForCausalLM.from_pretrained(
|
self.model = AutoModelForCausalLM.from_pretrained(
|
||||||
model_name, dtype=torch.bfloat16
|
model_name, dtype=torch.bfloat16
|
||||||
@@ -152,11 +161,14 @@ class TrainModel:
|
|||||||
|
|
||||||
def broadcast_weights(self, packed: bool = True):
|
def broadcast_weights(self, packed: bool = True):
|
||||||
"""Broadcast weights to the inference engine."""
|
"""Broadcast weights to the inference engine."""
|
||||||
NCCLWeightTransferEngine.trainer_send_weights(
|
trainer_args = NCCLTrainerSendWeightsArgs(
|
||||||
iterator=self.model.named_parameters(),
|
|
||||||
group=self.model_update_group,
|
group=self.model_update_group,
|
||||||
packed=packed,
|
packed=packed,
|
||||||
)
|
)
|
||||||
|
NCCLWeightTransferEngine.trainer_send_weights(
|
||||||
|
iterator=self.model.named_parameters(),
|
||||||
|
trainer_args=trainer_args,
|
||||||
|
)
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def generate(self, token_ids: list[int], max_new_tokens: int) -> list[int]:
|
def generate(self, token_ids: list[int], max_new_tokens: int) -> list[int]:
|
||||||
@@ -171,23 +183,48 @@ class TrainModel:
|
|||||||
return new_token_ids
|
return new_token_ids
|
||||||
|
|
||||||
|
|
||||||
ray.init(
|
# Build platform-specific env vars for Ray
|
||||||
runtime_env={
|
ray_env_vars = {
|
||||||
"env_vars": {
|
# Prevent Ray from setting CUDA_VISIBLE_DEVICES
|
||||||
# enable batch invariance for deterministic outputs
|
"RAY_EXPERIMENTAL_NOSET_CUDA_ENV_VAR": "1",
|
||||||
"VLLM_BATCH_INVARIANT": "1",
|
}
|
||||||
# prevent ray from setting CUDA_VISIBLE_DEVICES
|
|
||||||
"RAY_EXPERIMENTAL_NOSET_CUDA_ENV_VAR": "1",
|
if current_platform.is_rocm():
|
||||||
}
|
# For ROCm, BATCH_INVARIANT vllm is not supported
|
||||||
}
|
ray_env_vars["VLLM_ROCM_USE_SKINNY_GEMM"] = "0"
|
||||||
)
|
else:
|
||||||
|
# Enable batch invariance for deterministic outputs on NVIDIA
|
||||||
|
ray_env_vars["VLLM_BATCH_INVARIANT"] = "1"
|
||||||
|
|
||||||
|
ray.init(runtime_env={"env_vars": ray_env_vars})
|
||||||
|
|
||||||
# Launch the training model actor. Ray's resource scheduler will allocate
|
# Launch the training model actor. Ray's resource scheduler will allocate
|
||||||
# 1 GPU (via num_gpus=1 in the decorator), ensuring pg_inference gets different GPUs.
|
# 1 GPU (via num_gpus=1 in the decorator), ensuring pg_inference gets different GPUs.
|
||||||
train_model = TrainModel.remote(MODEL_NAME_V2)
|
train_model = TrainModel.remote(MODEL_NAME_V2)
|
||||||
|
|
||||||
# Launch the vLLM inference engine. The `enforce_eager` flag reduces
|
rocm_determinism_kwargs = {}
|
||||||
# start-up latency.
|
if current_platform.is_rocm():
|
||||||
|
# ROCm: To minimize non-determinism, we set fixed seed, no prefix caching, and
|
||||||
|
# sequential request processing (max_num_seqs=1).
|
||||||
|
rocm_determinism_kwargs = {
|
||||||
|
"seed": 0,
|
||||||
|
"enable_prefix_caching": False,
|
||||||
|
"max_num_seqs": 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Build platform-specific LLM kwargs
|
||||||
|
llm_kwargs = dict(
|
||||||
|
model=MODEL_NAME_V1,
|
||||||
|
enforce_eager=True,
|
||||||
|
max_model_len=8192,
|
||||||
|
distributed_executor_backend="ray",
|
||||||
|
attention_backend=ATTN_BACKEND,
|
||||||
|
gpu_memory_utilization=0.75,
|
||||||
|
weight_transfer_config=WeightTransferConfig(backend="nccl"),
|
||||||
|
)
|
||||||
|
llm_kwargs.update(rocm_determinism_kwargs)
|
||||||
|
|
||||||
|
# Launch the vLLM inference engine.
|
||||||
# With data_parallel_backend="ray", vLLM's CoreEngineActorManager creates
|
# With data_parallel_backend="ray", vLLM's CoreEngineActorManager creates
|
||||||
# its own placement groups internally for each DP rank, so we must NOT
|
# its own placement groups internally for each DP rank, so we must NOT
|
||||||
# create an outer placement group (it would reserve GPUs and hide them
|
# create an outer placement group (it would reserve GPUs and hide them
|
||||||
@@ -195,15 +232,7 @@ train_model = TrainModel.remote(MODEL_NAME_V2)
|
|||||||
llm = ray.remote(
|
llm = ray.remote(
|
||||||
num_cpus=0,
|
num_cpus=0,
|
||||||
num_gpus=0,
|
num_gpus=0,
|
||||||
)(MyLLM).remote(
|
)(MyLLM).remote(**llm_kwargs)
|
||||||
model=MODEL_NAME_V1,
|
|
||||||
enforce_eager=True,
|
|
||||||
max_model_len=8192,
|
|
||||||
distributed_executor_backend="ray",
|
|
||||||
attention_backend="FLASH_ATTN",
|
|
||||||
gpu_memory_utilization=0.75,
|
|
||||||
weight_transfer_config=WeightTransferConfig(backend="nccl"),
|
|
||||||
)
|
|
||||||
|
|
||||||
PROMPTS = [
|
PROMPTS = [
|
||||||
"The president of the United States is",
|
"The president of the United States is",
|
||||||
@@ -300,25 +329,42 @@ for i, (output, pause_idx) in enumerate(results):
|
|||||||
print(f" New weights ({n_after} tokens): {after_text!r}")
|
print(f" New weights ({n_after} tokens): {after_text!r}")
|
||||||
|
|
||||||
# ── Phase 2: validate with a fresh V2 vLLM instance ────────────────
|
# ── Phase 2: validate with a fresh V2 vLLM instance ────────────────
|
||||||
|
# This validation relies on batch-invariant (deterministic) generation to
|
||||||
|
# compare outputs from the weight-synced engine against a fresh V2 instance.
|
||||||
|
# On NVIDIA, batch invariance is fully supported, so we require 100% exact
|
||||||
|
# token match. On ROCm, batch invariance is not yet fully implemented
|
||||||
|
# (see https://github.com/vllm-project/vllm/issues/27433 and
|
||||||
|
# https://github.com/vllm-project/vllm/issues/33123), so residual
|
||||||
|
# non-determinism (e.g. GEMM accumulation order, missing kernel overrides)
|
||||||
|
# can cause single-token divergences that don't indicate a weight-sync
|
||||||
|
# failure. We relax the pass rate to 90% on ROCm to accommodate this; a
|
||||||
|
# real regression (broken weight transfer) would cause ~0% pass rate, not 90%+.
|
||||||
|
MIN_PASS_RATE = 1.0 if not current_platform.is_rocm() else 0.9
|
||||||
|
|
||||||
print(f"\n{'=' * 50}")
|
print(f"\n{'=' * 50}")
|
||||||
print("VALIDATION: comparing weight-synced vLLM with fresh V2 instance")
|
print("VALIDATION: comparing weight-synced vLLM with fresh V2 instance")
|
||||||
|
if current_platform.is_rocm():
|
||||||
|
print(f" (ROCm mode: requiring >= {MIN_PASS_RATE:.0%} exact match rate)")
|
||||||
print(f"{'=' * 50}")
|
print(f"{'=' * 50}")
|
||||||
|
|
||||||
ray.get(llm.shutdown.remote())
|
ray.get(llm.shutdown.remote())
|
||||||
ray.kill(llm)
|
ray.kill(llm)
|
||||||
ray.kill(train_model)
|
ray.kill(train_model)
|
||||||
|
|
||||||
llm_v2 = ray.remote(
|
llm_v2_kwargs = dict(
|
||||||
num_cpus=0,
|
|
||||||
num_gpus=0,
|
|
||||||
)(MyLLM).remote(
|
|
||||||
model=MODEL_NAME_V2,
|
model=MODEL_NAME_V2,
|
||||||
enforce_eager=True,
|
enforce_eager=True,
|
||||||
max_model_len=8192,
|
max_model_len=8192,
|
||||||
gpu_memory_utilization=0.75,
|
gpu_memory_utilization=0.75,
|
||||||
distributed_executor_backend="ray",
|
distributed_executor_backend="ray",
|
||||||
attention_backend="FLASH_ATTN",
|
attention_backend=ATTN_BACKEND,
|
||||||
)
|
)
|
||||||
|
llm_v2_kwargs.update(rocm_determinism_kwargs)
|
||||||
|
|
||||||
|
llm_v2 = ray.remote(
|
||||||
|
num_cpus=0,
|
||||||
|
num_gpus=0,
|
||||||
|
)(MyLLM).remote(**llm_v2_kwargs)
|
||||||
|
|
||||||
val_futures = [
|
val_futures = [
|
||||||
llm_v2.do_generate.remote(
|
llm_v2.do_generate.remote(
|
||||||
@@ -331,16 +377,17 @@ val_futures = [
|
|||||||
]
|
]
|
||||||
val_results = ray.get(val_futures)
|
val_results = ray.get(val_futures)
|
||||||
|
|
||||||
all_pass = True
|
num_pass = 0
|
||||||
|
num_total = len(results)
|
||||||
for i, ((output, pause_idx), (val_output, _)) in enumerate(zip(results, val_results)):
|
for i, ((output, pause_idx), (val_output, _)) in enumerate(zip(results, val_results)):
|
||||||
expected = list(output.outputs[0].token_ids)[pause_idx:]
|
expected = list(output.outputs[0].token_ids)[pause_idx:]
|
||||||
actual = list(val_output.outputs[0].token_ids)
|
actual = list(val_output.outputs[0].token_ids)
|
||||||
match = actual == expected
|
match = actual == expected
|
||||||
|
|
||||||
if match:
|
if match:
|
||||||
|
num_pass += 1
|
||||||
print(f" [PASS] {PROMPTS[i]!r}")
|
print(f" [PASS] {PROMPTS[i]!r}")
|
||||||
else:
|
else:
|
||||||
all_pass = False
|
|
||||||
print(f" [FAIL] {PROMPTS[i]!r}")
|
print(f" [FAIL] {PROMPTS[i]!r}")
|
||||||
print(f" weight-synced vLLM: {tokenizer.decode(expected)!r}")
|
print(f" weight-synced vLLM: {tokenizer.decode(expected)!r}")
|
||||||
print(f" V2 vLLM: {tokenizer.decode(actual)!r}")
|
print(f" V2 vLLM: {tokenizer.decode(actual)!r}")
|
||||||
@@ -355,5 +402,14 @@ for i, ((output, pause_idx), (val_output, _)) in enumerate(zip(results, val_resu
|
|||||||
|
|
||||||
ray.get(llm_v2.shutdown.remote())
|
ray.get(llm_v2.shutdown.remote())
|
||||||
ray.kill(llm_v2)
|
ray.kill(llm_v2)
|
||||||
assert all_pass, "Some prompts failed validation, see above for details"
|
|
||||||
|
pass_rate = num_pass / num_total
|
||||||
|
print(f"\n Result: {num_pass}/{num_total} prompts passed ({pass_rate:.0%})")
|
||||||
|
print(f" Required: >= {MIN_PASS_RATE:.0%}")
|
||||||
|
|
||||||
|
assert pass_rate >= MIN_PASS_RATE, (
|
||||||
|
f"Validation pass rate {pass_rate:.0%} ({num_pass}/{num_total}) "
|
||||||
|
f"is below the required {MIN_PASS_RATE:.0%} threshold. "
|
||||||
|
f"See failures above for details."
|
||||||
|
)
|
||||||
print("=" * 50)
|
print("=" * 50)
|
||||||
|
|||||||
149
examples/offline_inference/new_weight_syncing/rlhf_ipc.py
Normal file
149
examples/offline_inference/new_weight_syncing/rlhf_ipc.py
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""
|
||||||
|
Demonstrates reinforcement learning from human feedback (RLHF) using vLLM and Ray,
|
||||||
|
with IPC-based weight syncing APIs
|
||||||
|
|
||||||
|
The script colocates the training and inference workloads onto the same GPU using Ray.
|
||||||
|
|
||||||
|
The example performs the following steps:
|
||||||
|
|
||||||
|
* Request a placement group of 1 GPU.
|
||||||
|
* Place the inference model on the above GPU using the placement group.
|
||||||
|
* Place and load the training model on the same GPU using the placement group.
|
||||||
|
* Generate text from a list of prompts using the inference engine.
|
||||||
|
* Update the weights of the training model and broadcast the updated weights
|
||||||
|
to the inference engine by using CUDA IPC handles. Note that
|
||||||
|
for demonstration purposes we simply zero out the weights.
|
||||||
|
|
||||||
|
This example assumes a single-node cluster with a single GPU,
|
||||||
|
but can be extended to multiple GPUs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
import ray
|
||||||
|
from ray.util.placement_group import placement_group
|
||||||
|
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
||||||
|
from transformers import AutoModelForCausalLM
|
||||||
|
|
||||||
|
from vllm import LLM, SamplingParams
|
||||||
|
from vllm.config import WeightTransferConfig
|
||||||
|
from vllm.distributed.weight_transfer.ipc_engine import (
|
||||||
|
IPCTrainerSendWeightsArgs,
|
||||||
|
IPCWeightTransferEngine,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MyLLM(LLM):
|
||||||
|
"""Configure the vLLM worker for Ray placement group execution."""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
# Remove the top-level CUDA_VISIBLE_DEVICES variable set by Ray
|
||||||
|
# so that vLLM can manage its own device placement within the worker.
|
||||||
|
os.environ.pop("CUDA_VISIBLE_DEVICES", None)
|
||||||
|
# Each worker uses 0.4 GPU so that two instances fit on the same GPU.
|
||||||
|
os.environ["VLLM_RAY_PER_WORKER_GPUS"] = "0.4"
|
||||||
|
os.environ["VLLM_RAY_BUNDLE_INDICES"] = "0"
|
||||||
|
# needed for ipc handle serialization
|
||||||
|
os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1"
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
# Load the OPT-125M model onto GPU 0 for the training workload.
|
||||||
|
|
||||||
|
MODEL_NAME = "facebook/opt-125m"
|
||||||
|
|
||||||
|
|
||||||
|
@ray.remote
|
||||||
|
class TrainModel:
|
||||||
|
def __init__(self, llm_handle: ray.actor.ActorHandle):
|
||||||
|
self.train_model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
MODEL_NAME,
|
||||||
|
)
|
||||||
|
self.train_model.to("cuda:0")
|
||||||
|
self.llm_handle = llm_handle
|
||||||
|
|
||||||
|
def init_weight_transfer(self):
|
||||||
|
# IPC backend doesn't need initialization info
|
||||||
|
ray.get(
|
||||||
|
self.llm_handle.init_weight_transfer_engine.remote(dict(init_info=dict()))
|
||||||
|
)
|
||||||
|
|
||||||
|
def broadcast_weights(self, llm_handle: ray.actor.ActorHandle):
|
||||||
|
"""Broadcast weights to the inference engine using IPC."""
|
||||||
|
self.llm_handle = llm_handle
|
||||||
|
trainer_args = IPCTrainerSendWeightsArgs(mode="ray", llm_handle=llm_handle)
|
||||||
|
IPCWeightTransferEngine.trainer_send_weights(
|
||||||
|
iterator=self.train_model.named_parameters(),
|
||||||
|
trainer_args=trainer_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
ray.init()
|
||||||
|
|
||||||
|
pg_colocate = placement_group([{"GPU": 1, "CPU": 0}])
|
||||||
|
ray.get(pg_colocate.ready())
|
||||||
|
|
||||||
|
|
||||||
|
llm = ray.remote(
|
||||||
|
num_cpus=0,
|
||||||
|
num_gpus=0,
|
||||||
|
scheduling_strategy=PlacementGroupSchedulingStrategy(
|
||||||
|
placement_group=pg_colocate,
|
||||||
|
placement_group_capture_child_tasks=True,
|
||||||
|
),
|
||||||
|
)(MyLLM).remote(
|
||||||
|
model=MODEL_NAME,
|
||||||
|
enforce_eager=True,
|
||||||
|
tensor_parallel_size=1,
|
||||||
|
distributed_executor_backend="ray",
|
||||||
|
gpu_memory_utilization=0.7,
|
||||||
|
weight_transfer_config=WeightTransferConfig(backend="ipc"),
|
||||||
|
load_format="dummy",
|
||||||
|
)
|
||||||
|
|
||||||
|
train_model = TrainModel.options(
|
||||||
|
num_gpus=0.1,
|
||||||
|
num_cpus=0,
|
||||||
|
scheduling_strategy=PlacementGroupSchedulingStrategy(
|
||||||
|
placement_group=pg_colocate, placement_group_capture_child_tasks=True
|
||||||
|
),
|
||||||
|
).remote(llm)
|
||||||
|
|
||||||
|
|
||||||
|
# Generate text from the prompts.
|
||||||
|
prompts = [
|
||||||
|
"Hello, my name is",
|
||||||
|
"The president of the United States is",
|
||||||
|
"The capital of France is",
|
||||||
|
"The future of AI is",
|
||||||
|
]
|
||||||
|
|
||||||
|
sampling_params = SamplingParams(temperature=0)
|
||||||
|
|
||||||
|
outputs = ray.get(llm.generate.remote(prompts, sampling_params))
|
||||||
|
|
||||||
|
print("-" * 50)
|
||||||
|
for output in outputs:
|
||||||
|
prompt = output.prompt
|
||||||
|
generated_text = output.outputs[0].text
|
||||||
|
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
|
||||||
|
print("-" * 50)
|
||||||
|
|
||||||
|
ray.get(llm.sleep.remote(level=0))
|
||||||
|
|
||||||
|
ray.get(train_model.init_weight_transfer.remote())
|
||||||
|
# Synchronize the updated weights to the inference engine using batched API.
|
||||||
|
ray.get(train_model.broadcast_weights.remote(llm))
|
||||||
|
|
||||||
|
ray.get(llm.wake_up.remote(tags=["scheduling"]))
|
||||||
|
|
||||||
|
# Generate text with the updated model.
|
||||||
|
outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params))
|
||||||
|
print("-" * 50)
|
||||||
|
for output in outputs_updated:
|
||||||
|
prompt = output.prompt
|
||||||
|
generated_text = output.outputs[0].text
|
||||||
|
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
|
||||||
|
print("-" * 50)
|
||||||
@@ -36,6 +36,7 @@ from transformers import AutoModelForCausalLM
|
|||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
from vllm.config import WeightTransferConfig
|
from vllm.config import WeightTransferConfig
|
||||||
from vllm.distributed.weight_transfer.nccl_engine import (
|
from vllm.distributed.weight_transfer.nccl_engine import (
|
||||||
|
NCCLTrainerSendWeightsArgs,
|
||||||
NCCLWeightTransferEngine,
|
NCCLWeightTransferEngine,
|
||||||
)
|
)
|
||||||
from vllm.utils.network_utils import get_ip, get_open_port
|
from vllm.utils.network_utils import get_ip, get_open_port
|
||||||
@@ -90,11 +91,14 @@ class TrainModel:
|
|||||||
|
|
||||||
def broadcast_weights(self, packed: bool = True):
|
def broadcast_weights(self, packed: bool = True):
|
||||||
"""Broadcast weights to the inference engine."""
|
"""Broadcast weights to the inference engine."""
|
||||||
NCCLWeightTransferEngine.trainer_send_weights(
|
trainer_args = NCCLTrainerSendWeightsArgs(
|
||||||
iterator=self.model.named_parameters(),
|
|
||||||
group=self.model_update_group,
|
group=self.model_update_group,
|
||||||
packed=packed,
|
packed=packed,
|
||||||
)
|
)
|
||||||
|
NCCLWeightTransferEngine.trainer_send_weights(
|
||||||
|
iterator=self.model.named_parameters(),
|
||||||
|
trainer_args=trainer_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Initialize Ray and set the visible devices. The vLLM engine will
|
# Initialize Ray and set the visible devices. The vLLM engine will
|
||||||
@@ -156,6 +160,8 @@ for output in outputs:
|
|||||||
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
|
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
|
||||||
print("-" * 50)
|
print("-" * 50)
|
||||||
|
|
||||||
|
ray.get(llm.sleep.remote(level=0))
|
||||||
|
|
||||||
# Set up the communication channel between the training process and the
|
# Set up the communication channel between the training process and the
|
||||||
# inference engine.
|
# inference engine.
|
||||||
master_address, master_port = ray.get(train_model.get_master_address_and_port.remote())
|
master_address, master_port = ray.get(train_model.get_master_address_and_port.remote())
|
||||||
@@ -197,6 +203,8 @@ inference_handle = llm.update_weights.remote(
|
|||||||
train_handle = train_model.broadcast_weights.remote(packed=True)
|
train_handle = train_model.broadcast_weights.remote(packed=True)
|
||||||
ray.get([train_handle, inference_handle])
|
ray.get([train_handle, inference_handle])
|
||||||
|
|
||||||
|
ray.get(llm.wake_up.remote(tags=["scheduling"]))
|
||||||
|
|
||||||
# Generate text with the updated model. The output is expected to be normal
|
# Generate text with the updated model. The output is expected to be normal
|
||||||
# because the weights are updated.
|
# because the weights are updated.
|
||||||
outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params))
|
outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params))
|
||||||
@@ -64,7 +64,7 @@ vllm serve "$MODEL_NAME" \
|
|||||||
--enforce-eager \
|
--enforce-eager \
|
||||||
--enable-expert-parallel \
|
--enable-expert-parallel \
|
||||||
--enable-eplb \
|
--enable-eplb \
|
||||||
--all2all-backend pplx \
|
--all2all-backend allgather_reducescatter \
|
||||||
--num-redundant-experts "$REDUNDANT_EXPERTS" \
|
--num-redundant-experts "$REDUNDANT_EXPERTS" \
|
||||||
--trust-remote-code \
|
--trust-remote-code \
|
||||||
--host "$HOST" \
|
--host "$HOST" \
|
||||||
|
|||||||
181
examples/online_serving/new_weight_syncing/rlhf_http_ipc.py
Normal file
181
examples/online_serving/new_weight_syncing/rlhf_http_ipc.py
Normal file
@@ -0,0 +1,181 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""
|
||||||
|
Demonstrates reinforcement learning from human feedback (RLHF) using vLLM
|
||||||
|
via HTTP API, with IPC-based weight syncing APIs.
|
||||||
|
|
||||||
|
Unlike rlhf_nccl.py which uses NCCL and can use separate GPUs, this script
|
||||||
|
uses CUDA IPC which requires the training model and vLLM server to be on the
|
||||||
|
same GPU. Memory must be carefully managed to fit both models.
|
||||||
|
|
||||||
|
Unlike rlhf.py which creates a vLLM instance programmatically, this script
|
||||||
|
assumes you have already started a vLLM server using `vllm serve`. It uses:
|
||||||
|
- OpenAI-compatible API for inference requests
|
||||||
|
- HTTP endpoints for weight transfer control plane
|
||||||
|
- CUDA IPC for actual weight data transfer
|
||||||
|
|
||||||
|
Prerequisites:
|
||||||
|
Start a vLLM server with weight transfer enabled and reduced GPU memory
|
||||||
|
utilization to leave room for the training model:
|
||||||
|
|
||||||
|
$ VLLM_SERVER_DEV_MODE=1 VLLM_ALLOW_INSECURE_SERIALIZATION=1 \
|
||||||
|
vllm serve facebook/opt-125m --enforce-eager \
|
||||||
|
--weight-transfer-config '{"backend": "ipc"}' \
|
||||||
|
--load-format dummy \
|
||||||
|
--gpu-memory-utilization 0.5
|
||||||
|
|
||||||
|
Then run this script:
|
||||||
|
|
||||||
|
$ python rlhf_http_ipc.py
|
||||||
|
|
||||||
|
The example performs the following steps:
|
||||||
|
|
||||||
|
* Load the training model on GPU 0 (same GPU as the vLLM server).
|
||||||
|
* Generate text using the vLLM server via OpenAI-compatible API. The output
|
||||||
|
is expected to be nonsense because the server is initialized with dummy weights.
|
||||||
|
* Initialize weight transfer via HTTP endpoint (no-op for IPC).
|
||||||
|
* Broadcast the real weights from the training model to the vLLM server
|
||||||
|
using CUDA IPC handles.
|
||||||
|
* Generate text again to show normal output after the weight update.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
import requests
|
||||||
|
import torch
|
||||||
|
from openai import OpenAI
|
||||||
|
from transformers import AutoModelForCausalLM
|
||||||
|
|
||||||
|
from vllm.distributed.weight_transfer.ipc_engine import (
|
||||||
|
IPCTrainerSendWeightsArgs,
|
||||||
|
IPCWeightTransferEngine,
|
||||||
|
)
|
||||||
|
|
||||||
|
BASE_URL = "http://localhost:8000"
|
||||||
|
MODEL_NAME = "facebook/opt-125m"
|
||||||
|
|
||||||
|
# Enable insecure serialization for IPC handle serialization
|
||||||
|
os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1"
|
||||||
|
|
||||||
|
|
||||||
|
def generate_completions(client: OpenAI, model: str, prompts: list[str]) -> list[str]:
|
||||||
|
"""Generate completions using the OpenAI-compatible API."""
|
||||||
|
results = []
|
||||||
|
for prompt in prompts:
|
||||||
|
response = client.completions.create(
|
||||||
|
model=model,
|
||||||
|
prompt=prompt,
|
||||||
|
max_tokens=32,
|
||||||
|
temperature=0,
|
||||||
|
)
|
||||||
|
results.append(response.choices[0].text)
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def init_weight_transfer_engine(base_url: str) -> None:
|
||||||
|
"""Initialize weight transfer via HTTP endpoint (no-op for IPC)."""
|
||||||
|
url = f"{base_url}/init_weight_transfer_engine"
|
||||||
|
payload = {"init_info": dict()}
|
||||||
|
response = requests.post(url, json=payload, timeout=60)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
|
||||||
|
def pause_generation(base_url: str) -> None:
|
||||||
|
"""Pause generation via HTTP endpoint."""
|
||||||
|
url = f"{base_url}/pause"
|
||||||
|
response = requests.post(url, timeout=60)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
|
||||||
|
def resume_generation(base_url: str) -> None:
|
||||||
|
"""Resume generation via HTTP endpoint."""
|
||||||
|
url = f"{base_url}/resume"
|
||||||
|
response = requests.post(url, timeout=60)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
|
||||||
|
def get_world_size(base_url: str) -> int:
|
||||||
|
"""Get world size from the vLLM server."""
|
||||||
|
url = f"{base_url}/get_world_size"
|
||||||
|
response = requests.get(url, timeout=10)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()["world_size"]
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# IPC requires the training model to be on the same GPU as the vLLM server
|
||||||
|
# The server should be started on GPU 0 with reduced memory utilization
|
||||||
|
device = "cuda:0"
|
||||||
|
torch.cuda.set_device(device)
|
||||||
|
|
||||||
|
# Load the training model on the same GPU as the server
|
||||||
|
# Use bfloat16 to reduce memory footprint
|
||||||
|
print(f"Loading training model: {MODEL_NAME} on {device}")
|
||||||
|
print(
|
||||||
|
"Note: Ensure the vLLM server was started with --gpu-memory-utilization 0.5 "
|
||||||
|
"or lower to leave room for the training model."
|
||||||
|
)
|
||||||
|
train_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, dtype=torch.bfloat16)
|
||||||
|
train_model.to(device)
|
||||||
|
train_model.eval() # Set to eval mode to save memory
|
||||||
|
|
||||||
|
# Create OpenAI client pointing to the vLLM server
|
||||||
|
client = OpenAI(
|
||||||
|
base_url=f"{BASE_URL}/v1",
|
||||||
|
api_key="EMPTY", # vLLM doesn't require an API key by default
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test prompts
|
||||||
|
prompts = [
|
||||||
|
"Hello, my name is",
|
||||||
|
"The president of the United States is",
|
||||||
|
"The capital of France is",
|
||||||
|
"The future of AI is",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Generate text before weight update. The output is expected to be nonsense
|
||||||
|
# because the server is initialized with dummy weights.
|
||||||
|
print("-" * 50)
|
||||||
|
print("Generating text BEFORE weight update (expect nonsense):")
|
||||||
|
print("-" * 50)
|
||||||
|
outputs = generate_completions(client, MODEL_NAME, prompts)
|
||||||
|
for prompt, generated_text in zip(prompts, outputs):
|
||||||
|
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
|
||||||
|
print("-" * 50)
|
||||||
|
|
||||||
|
print("Initializing weight transfer (IPC backend)...")
|
||||||
|
|
||||||
|
# Initialize weight transfer on vLLM server (no-op for IPC, but still required)
|
||||||
|
init_weight_transfer_engine(BASE_URL)
|
||||||
|
|
||||||
|
# Pause generation before weight sync
|
||||||
|
pause_generation(BASE_URL)
|
||||||
|
|
||||||
|
# Broadcast weights via IPC handles using HTTP mode
|
||||||
|
print("Broadcasting weights via CUDA IPC (HTTP)...")
|
||||||
|
trainer_args = IPCTrainerSendWeightsArgs(mode="http", url=BASE_URL)
|
||||||
|
IPCWeightTransferEngine.trainer_send_weights(
|
||||||
|
iterator=train_model.named_parameters(),
|
||||||
|
trainer_args=trainer_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Resume generation after weight sync
|
||||||
|
resume_generation(BASE_URL)
|
||||||
|
|
||||||
|
# Generate text after weight update. The output is expected to be normal
|
||||||
|
# because the real weights are now loaded.
|
||||||
|
print("-" * 50)
|
||||||
|
print("Generating text AFTER weight update:")
|
||||||
|
print("-" * 50)
|
||||||
|
outputs_updated = generate_completions(client, MODEL_NAME, prompts)
|
||||||
|
for prompt, generated_text in zip(prompts, outputs_updated):
|
||||||
|
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
|
||||||
|
print("-" * 50)
|
||||||
|
|
||||||
|
# Note: The training model and IPC handles remain in memory.
|
||||||
|
# In a real RLHF training loop, you would update the training model
|
||||||
|
# and create new IPC handles for each weight update.
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -39,6 +39,7 @@ from openai import OpenAI
|
|||||||
from transformers import AutoModelForCausalLM
|
from transformers import AutoModelForCausalLM
|
||||||
|
|
||||||
from vllm.distributed.weight_transfer.nccl_engine import (
|
from vllm.distributed.weight_transfer.nccl_engine import (
|
||||||
|
NCCLTrainerSendWeightsArgs,
|
||||||
NCCLWeightTransferEngine,
|
NCCLWeightTransferEngine,
|
||||||
)
|
)
|
||||||
from vllm.utils.network_utils import get_ip, get_open_port
|
from vllm.utils.network_utils import get_ip, get_open_port
|
||||||
@@ -214,11 +215,14 @@ def main():
|
|||||||
|
|
||||||
# Broadcast all weights from trainer to vLLM workers
|
# Broadcast all weights from trainer to vLLM workers
|
||||||
print("Broadcasting weights via NCCL...")
|
print("Broadcasting weights via NCCL...")
|
||||||
NCCLWeightTransferEngine.trainer_send_weights(
|
trainer_args = NCCLTrainerSendWeightsArgs(
|
||||||
iterator=train_model.named_parameters(),
|
|
||||||
group=model_update_group,
|
group=model_update_group,
|
||||||
packed=True,
|
packed=True,
|
||||||
)
|
)
|
||||||
|
NCCLWeightTransferEngine.trainer_send_weights(
|
||||||
|
iterator=train_model.named_parameters(),
|
||||||
|
trainer_args=trainer_args,
|
||||||
|
)
|
||||||
|
|
||||||
# Wait for update_weights to complete
|
# Wait for update_weights to complete
|
||||||
update_thread.join()
|
update_thread.join()
|
||||||
20
examples/pooling/embed/template/nemotron_embed_vl.jinja
Normal file
20
examples/pooling/embed/template/nemotron_embed_vl.jinja
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
{%- if messages | length > 1 -%}
|
||||||
|
{{ raise_exception('Embedding models should only embed one message at a time') }}
|
||||||
|
{%- endif -%}
|
||||||
|
|
||||||
|
{% set vars = namespace(prefix='', images=[], texts=[]) %}
|
||||||
|
{%- for message in messages -%}
|
||||||
|
{%- if message['role'] == 'query' -%}
|
||||||
|
{%- set vars.prefix = 'query: ' %}
|
||||||
|
{%- elif message['role'] == 'document' -%}
|
||||||
|
{%- set vars.prefix = 'passage: ' %}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- for content in message['content'] -%}
|
||||||
|
{%- if content['type'] == 'text' -%}
|
||||||
|
{%- set vars.texts = vars.texts + [content['text']] %}
|
||||||
|
{%- elif content['type'] == 'image' -%}
|
||||||
|
{%- set vars.images = vars.images + ['<image> '] %}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endfor -%}
|
||||||
|
{%- endfor -%}
|
||||||
|
{{- bos_token }}{{ vars.prefix }}{{ (vars.images + vars.texts) | join('') }}
|
||||||
15
examples/pooling/score/template/nemotron-vl-rerank.jinja
Normal file
15
examples/pooling/score/template/nemotron-vl-rerank.jinja
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
{%- set query_msg = (messages | selectattr('role', 'equalto', 'query') | list | first) -%}
|
||||||
|
{%- set doc_msg = (messages | selectattr('role', 'equalto', 'document') | list | first) -%}
|
||||||
|
|
||||||
|
{%- set q = query_msg['content'] -%}
|
||||||
|
{%- set d = doc_msg['content'] -%}
|
||||||
|
|
||||||
|
{# If the doc contains <image> anywhere, hoist a single <image> to the front #}
|
||||||
|
{%- set has_image = ("<image>" in d) -%}
|
||||||
|
{%- set d_clean = d | replace("<image>", "") -%}
|
||||||
|
{%- set q_clean = q | replace("<image>", "") -%}
|
||||||
|
|
||||||
|
{%- if has_image -%}<image>{{ " " }}{%- endif -%}
|
||||||
|
question:{{ q_clean }}{{ " " }}
|
||||||
|
{{ " " }}
|
||||||
|
{{ " " }}passage:{{ d_clean }}
|
||||||
@@ -42,6 +42,7 @@ theme:
|
|||||||
- navigation.sections
|
- navigation.sections
|
||||||
- navigation.indexes
|
- navigation.indexes
|
||||||
- navigation.top
|
- navigation.top
|
||||||
|
- navigation.path
|
||||||
- search.highlight
|
- search.highlight
|
||||||
- search.share
|
- search.share
|
||||||
- toc.follow
|
- toc.follow
|
||||||
|
|||||||
@@ -117,7 +117,6 @@ markers = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
[tool.ty.src]
|
[tool.ty.src]
|
||||||
root = "./vllm"
|
|
||||||
respect-ignore-files = true
|
respect-ignore-files = true
|
||||||
|
|
||||||
[tool.ty.environment]
|
[tool.ty.environment]
|
||||||
|
|||||||
@@ -57,3 +57,4 @@ opentelemetry-sdk >= 1.27.0
|
|||||||
opentelemetry-api >= 1.27.0
|
opentelemetry-api >= 1.27.0
|
||||||
opentelemetry-exporter-otlp >= 1.27.0
|
opentelemetry-exporter-otlp >= 1.27.0
|
||||||
opentelemetry-semantic-conventions-ai >= 0.4.1
|
opentelemetry-semantic-conventions-ai >= 0.4.1
|
||||||
|
kaldi-native-fbank >= 1.18.7
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user