Compare commits
63 Commits
v0.19.0rc0
...
v0.18.2rc0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b6e636c12c | ||
|
|
f1ff50c86c | ||
|
|
757068dc65 | ||
|
|
7337ff7f03 | ||
|
|
5869f69c5f | ||
|
|
4dfad17ed1 | ||
|
|
e8057c00bc | ||
|
|
7430389669 | ||
|
|
202f147cf2 | ||
|
|
ea7bfde6e4 | ||
|
|
d71a15041f | ||
|
|
abdbb68386 | ||
|
|
0c63739135 | ||
|
|
719735d6c5 | ||
|
|
aae3e688f8 | ||
|
|
7d65463528 | ||
|
|
8278825b57 | ||
|
|
acf7292bf2 | ||
|
|
ce884756f0 | ||
|
|
d9d21eb8e3 | ||
|
|
f09daea261 | ||
|
|
42318c840b | ||
|
|
1ac6694297 | ||
|
|
6cc7abdc66 | ||
|
|
d53cb9cb8e | ||
|
|
44eef0ca1e | ||
|
|
b9cdc85207 | ||
|
|
3e802e8786 | ||
|
|
350af48e14 | ||
|
|
e31915063d | ||
|
|
29e48707e8 | ||
|
|
4ac227222f | ||
|
|
bb51d5b40d | ||
|
|
93b3ec1585 | ||
|
|
e812bf70bd | ||
|
|
bcc6f67447 | ||
|
|
1fc69f59bb | ||
|
|
d9c7db18da | ||
|
|
12701e8af2 | ||
|
|
494636b29d | ||
|
|
ab1a6a43fa | ||
|
|
b5e608258e | ||
|
|
2c734ed0e0 | ||
|
|
3b1dbaad4e | ||
|
|
b4a2f3ac36 | ||
|
|
8e6293e838 | ||
|
|
dbdd9ae067 | ||
|
|
e8b055a5ac | ||
|
|
246dc7d864 | ||
|
|
7c3f88b2a8 | ||
|
|
6557f4937f | ||
|
|
677424c7ac | ||
|
|
1031c84c36 | ||
|
|
7e76af14fa | ||
|
|
3683fe6c06 | ||
|
|
cc06b4e86b | ||
|
|
03ac6ca895 | ||
|
|
a08b7733fd | ||
|
|
85c0950b1f | ||
|
|
57861ae48d | ||
|
|
ac30a8311e | ||
|
|
63babd17f1 | ||
|
|
fec5aeca12 |
@@ -5,6 +5,7 @@ steps:
|
|||||||
depends_on: []
|
depends_on: []
|
||||||
device: amd_cpu
|
device: amd_cpu
|
||||||
no_plugin: true
|
no_plugin: true
|
||||||
|
soft_fail: true
|
||||||
commands:
|
commands:
|
||||||
- >
|
- >
|
||||||
docker build
|
docker build
|
||||||
@@ -20,11 +21,3 @@ steps:
|
|||||||
- docker push "rocm/vllm-ci:${BUILDKITE_COMMIT}"
|
- docker push "rocm/vllm-ci:${BUILDKITE_COMMIT}"
|
||||||
env:
|
env:
|
||||||
DOCKER_BUILDKIT: "1"
|
DOCKER_BUILDKIT: "1"
|
||||||
retry:
|
|
||||||
automatic:
|
|
||||||
- exit_status: -1 # Agent was lost
|
|
||||||
limit: 1
|
|
||||||
- exit_status: -10 # Agent was lost
|
|
||||||
limit: 1
|
|
||||||
- exit_status: 1 # Machine occasionally fail
|
|
||||||
limit: 1
|
|
||||||
|
|||||||
@@ -13,12 +13,14 @@ steps:
|
|||||||
- tests/kernels/attention/test_cpu_attn.py
|
- tests/kernels/attention/test_cpu_attn.py
|
||||||
- tests/kernels/moe/test_cpu_fused_moe.py
|
- tests/kernels/moe/test_cpu_fused_moe.py
|
||||||
- tests/kernels/test_onednn.py
|
- tests/kernels/test_onednn.py
|
||||||
|
- tests/kernels/test_awq_int4_to_int8.py
|
||||||
commands:
|
commands:
|
||||||
- |
|
- |
|
||||||
bash .buildkite/scripts/hardware_ci/run-cpu-test.sh 20m "
|
bash .buildkite/scripts/hardware_ci/run-cpu-test.sh 20m "
|
||||||
pytest -x -v -s tests/kernels/attention/test_cpu_attn.py
|
pytest -x -v -s tests/kernels/attention/test_cpu_attn.py
|
||||||
pytest -x -v -s tests/kernels/moe/test_cpu_fused_moe.py
|
pytest -x -v -s tests/kernels/moe/test_cpu_fused_moe.py
|
||||||
pytest -x -v -s tests/kernels/test_onednn.py"
|
pytest -x -v -s tests/kernels/test_onednn.py
|
||||||
|
pytest -x -v -s tests/kernels/test_awq_int4_to_int8.py"
|
||||||
|
|
||||||
- label: CPU-Compatibility Tests
|
- label: CPU-Compatibility Tests
|
||||||
depends_on: []
|
depends_on: []
|
||||||
|
|||||||
@@ -36,6 +36,7 @@
|
|||||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
"backend": "vllm",
|
"backend": "vllm",
|
||||||
"ignore-eos": "",
|
"ignore-eos": "",
|
||||||
|
"temperature": 0,
|
||||||
"num_prompts": 200
|
"num_prompts": 200
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@@ -127,4 +128,4 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,6 +22,7 @@
|
|||||||
"hf_split": "test",
|
"hf_split": "test",
|
||||||
"no_stream": "",
|
"no_stream": "",
|
||||||
"no_oversample": "",
|
"no_oversample": "",
|
||||||
|
"temperature": 0,
|
||||||
"num_prompts": 200
|
"num_prompts": 200
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -26,6 +26,7 @@
|
|||||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
"backend": "vllm",
|
"backend": "vllm",
|
||||||
"ignore-eos": "",
|
"ignore-eos": "",
|
||||||
|
"temperature": 0,
|
||||||
"num_prompts": 200
|
"num_prompts": 200
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -26,6 +26,7 @@
|
|||||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
"backend": "vllm",
|
"backend": "vllm",
|
||||||
"ignore-eos": "",
|
"ignore-eos": "",
|
||||||
|
"temperature": 0,
|
||||||
"num_prompts": 200
|
"num_prompts": 200
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -21,6 +21,7 @@
|
|||||||
"backend": "vllm",
|
"backend": "vllm",
|
||||||
"dataset_name": "sharegpt",
|
"dataset_name": "sharegpt",
|
||||||
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||||
|
"temperature": 0,
|
||||||
"num_prompts": 200
|
"num_prompts": 200
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@@ -47,6 +48,7 @@
|
|||||||
"backend": "vllm",
|
"backend": "vllm",
|
||||||
"dataset_name": "sharegpt",
|
"dataset_name": "sharegpt",
|
||||||
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||||
|
"temperature": 0,
|
||||||
"num_prompts": 200
|
"num_prompts": 200
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@@ -73,6 +75,7 @@
|
|||||||
"backend": "vllm",
|
"backend": "vllm",
|
||||||
"dataset_name": "sharegpt",
|
"dataset_name": "sharegpt",
|
||||||
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||||
|
"temperature": 0,
|
||||||
"num_prompts": 200
|
"num_prompts": 200
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@@ -100,6 +103,7 @@
|
|||||||
"backend": "vllm",
|
"backend": "vllm",
|
||||||
"dataset_name": "sharegpt",
|
"dataset_name": "sharegpt",
|
||||||
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||||
|
"temperature": 0,
|
||||||
"num_prompts": 200
|
"num_prompts": 200
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@@ -127,6 +131,7 @@
|
|||||||
"backend": "vllm",
|
"backend": "vllm",
|
||||||
"dataset_name": "sharegpt",
|
"dataset_name": "sharegpt",
|
||||||
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||||
|
"temperature": 0,
|
||||||
"num_prompts": 200
|
"num_prompts": 200
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@@ -151,6 +156,7 @@
|
|||||||
"backend": "vllm",
|
"backend": "vllm",
|
||||||
"dataset_name": "sharegpt",
|
"dataset_name": "sharegpt",
|
||||||
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||||
|
"temperature": 0,
|
||||||
"num_prompts": 200
|
"num_prompts": 200
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,6 +13,7 @@
|
|||||||
"backend": "vllm",
|
"backend": "vllm",
|
||||||
"dataset_name": "sharegpt",
|
"dataset_name": "sharegpt",
|
||||||
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||||
|
"temperature": 0,
|
||||||
"num_prompts": 200
|
"num_prompts": 200
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@@ -30,6 +31,7 @@
|
|||||||
"backend": "vllm",
|
"backend": "vllm",
|
||||||
"dataset_name": "sharegpt",
|
"dataset_name": "sharegpt",
|
||||||
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||||
|
"temperature": 0,
|
||||||
"num_prompts": 200
|
"num_prompts": 200
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@@ -47,6 +49,7 @@
|
|||||||
"backend": "vllm",
|
"backend": "vllm",
|
||||||
"dataset_name": "sharegpt",
|
"dataset_name": "sharegpt",
|
||||||
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||||
|
"temperature": 0,
|
||||||
"num_prompts": 200
|
"num_prompts": 200
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@@ -67,6 +70,7 @@
|
|||||||
"backend": "vllm",
|
"backend": "vllm",
|
||||||
"dataset_name": "sharegpt",
|
"dataset_name": "sharegpt",
|
||||||
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||||
|
"temperature": 0,
|
||||||
"num_prompts": 200
|
"num_prompts": 200
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -239,13 +239,29 @@ fi
|
|||||||
# --- Docker housekeeping ---
|
# --- Docker housekeeping ---
|
||||||
cleanup_docker
|
cleanup_docker
|
||||||
|
|
||||||
|
aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin "$REGISTRY"
|
||||||
|
aws ecr get-login-password --region us-east-1 | docker login --username AWS --password-stdin 936637512419.dkr.ecr.us-east-1.amazonaws.com
|
||||||
|
|
||||||
# --- Build or pull test image ---
|
# --- Build or pull test image ---
|
||||||
if [[ -n "${IMAGE_TAG_XPU:-}" ]]; then
|
IMAGE="${IMAGE_TAG_XPU:-${image_name}}"
|
||||||
echo "Using prebuilt XPU image: ${IMAGE_TAG_XPU}"
|
|
||||||
docker pull "${IMAGE_TAG_XPU}"
|
echo "Using image: ${IMAGE}"
|
||||||
|
|
||||||
|
if docker image inspect "${IMAGE}" >/dev/null 2>&1; then
|
||||||
|
echo "Image already exists locally, skipping pull"
|
||||||
else
|
else
|
||||||
echo "Using prebuilt XPU image: ${image_name}"
|
echo "Image not found locally, waiting for lock..."
|
||||||
docker pull "${image_name}"
|
|
||||||
|
flock /tmp/docker-pull.lock bash -c "
|
||||||
|
if docker image inspect '${IMAGE}' >/dev/null 2>&1; then
|
||||||
|
echo 'Image already pulled by another runner'
|
||||||
|
else
|
||||||
|
echo 'Pulling image...'
|
||||||
|
timeout 900 docker pull '${IMAGE}'
|
||||||
|
fi
|
||||||
|
"
|
||||||
|
|
||||||
|
echo "Pull step completed"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
remove_docker_container() {
|
remove_docker_container() {
|
||||||
|
|||||||
@@ -2,14 +2,6 @@ group: Benchmarks
|
|||||||
depends_on:
|
depends_on:
|
||||||
- image-build
|
- image-build
|
||||||
steps:
|
steps:
|
||||||
- label: Benchmarks
|
|
||||||
timeout_in_minutes: 20
|
|
||||||
working_dir: "/vllm-workspace/.buildkite"
|
|
||||||
source_file_dependencies:
|
|
||||||
- benchmarks/
|
|
||||||
commands:
|
|
||||||
- bash scripts/run-benchmarks.sh
|
|
||||||
|
|
||||||
- label: Benchmarks CLI Test
|
- label: Benchmarks CLI Test
|
||||||
timeout_in_minutes: 20
|
timeout_in_minutes: 20
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
|
|||||||
@@ -13,8 +13,8 @@ steps:
|
|||||||
- pytest -v -s distributed/test_eplb_algo.py
|
- pytest -v -s distributed/test_eplb_algo.py
|
||||||
- pytest -v -s distributed/test_eplb_utils.py
|
- pytest -v -s distributed/test_eplb_utils.py
|
||||||
|
|
||||||
- label: EPLB Execution
|
- label: EPLB Execution # 17min
|
||||||
timeout_in_minutes: 20
|
timeout_in_minutes: 27
|
||||||
working_dir: "/vllm-workspace/tests"
|
working_dir: "/vllm-workspace/tests"
|
||||||
num_devices: 4
|
num_devices: 4
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
|
|||||||
26
.github/CODEOWNERS
vendored
26
.github/CODEOWNERS
vendored
@@ -2,14 +2,15 @@
|
|||||||
# for more info about CODEOWNERS file
|
# for more info about CODEOWNERS file
|
||||||
|
|
||||||
# This lists cover the "core" components of vLLM that require careful review
|
# This lists cover the "core" components of vLLM that require careful review
|
||||||
/vllm/compilation @zou3519 @youkaichao @ProExpertProg @BoyuanFeng
|
/vllm/compilation @zou3519 @youkaichao @ProExpertProg @BoyuanFeng @vadiklyutiy
|
||||||
/vllm/distributed/kv_transfer @NickLucche @ApostaC @orozery
|
/vllm/distributed/kv_transfer @NickLucche @ApostaC @orozery
|
||||||
/vllm/lora @jeejeelee
|
/vllm/lora @jeejeelee
|
||||||
/vllm/model_executor/layers/attention @LucasWilkinson @MatthewBonanni
|
/vllm/model_executor/layers/attention @LucasWilkinson @MatthewBonanni
|
||||||
/vllm/model_executor/layers/fused_moe @mgoin @pavanimajety
|
/vllm/model_executor/layers/fused_moe @mgoin @pavanimajety
|
||||||
/vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth @yewentao256 @pavanimajety
|
/vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth @yewentao256 @pavanimajety
|
||||||
/vllm/model_executor/layers/mamba @tdoublep
|
/vllm/model_executor/layers/mamba @tdoublep @tomeras91
|
||||||
/vllm/model_executor/layers/mamba/gdn_linear_attn.py @tdoublep @ZJY0516
|
/vllm/model_executor/layers/mamba/gdn_linear_attn.py @tdoublep @ZJY0516 @vadiklyutiy
|
||||||
|
/vllm/model_executor/layers/rotary_embedding.py @vadiklyutiy
|
||||||
/vllm/model_executor/model_loader @22quinn
|
/vllm/model_executor/model_loader @22quinn
|
||||||
/vllm/model_executor/layers/batch_invariant.py @yewentao256
|
/vllm/model_executor/layers/batch_invariant.py @yewentao256
|
||||||
/vllm/multimodal @DarkLight1337 @ywang96 @NickLucche @tjtanaa
|
/vllm/multimodal @DarkLight1337 @ywang96 @NickLucche @tjtanaa
|
||||||
@@ -47,9 +48,9 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
|
|||||||
/vllm/v1/attention @LucasWilkinson @MatthewBonanni
|
/vllm/v1/attention @LucasWilkinson @MatthewBonanni
|
||||||
/vllm/v1/attention/backend.py @WoosukKwon @zhuohan123 @youkaichao @alexm-redhat @njhill
|
/vllm/v1/attention/backend.py @WoosukKwon @zhuohan123 @youkaichao @alexm-redhat @njhill
|
||||||
/vllm/v1/attention/backends/mla @pavanimajety
|
/vllm/v1/attention/backends/mla @pavanimajety
|
||||||
/vllm/v1/attention/backends/flashinfer.py @mgoin @pavanimajety
|
/vllm/v1/attention/backends/flashinfer.py @mgoin @pavanimajety @vadiklyutiy
|
||||||
/vllm/v1/attention/backends/triton_attn.py @tdoublep
|
/vllm/v1/attention/backends/triton_attn.py @tdoublep
|
||||||
/vllm/v1/attention/backends/gdn_attn.py @ZJY0516
|
/vllm/v1/attention/backends/gdn_attn.py @ZJY0516 @vadiklyutiy
|
||||||
/vllm/v1/core @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @alexm-redhat @heheda12345 @ApostaC @orozery
|
/vllm/v1/core @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @alexm-redhat @heheda12345 @ApostaC @orozery
|
||||||
/vllm/v1/sample @22quinn @houseroad @njhill
|
/vllm/v1/sample @22quinn @houseroad @njhill
|
||||||
/vllm/v1/spec_decode @benchislett @luccafong @MatthewBonanni
|
/vllm/v1/spec_decode @benchislett @luccafong @MatthewBonanni
|
||||||
@@ -71,7 +72,7 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
|
|||||||
/tests/distributed/test_pipeline_parallel.py @youkaichao
|
/tests/distributed/test_pipeline_parallel.py @youkaichao
|
||||||
/tests/distributed/test_same_node.py @youkaichao
|
/tests/distributed/test_same_node.py @youkaichao
|
||||||
/tests/entrypoints @DarkLight1337 @robertgshaw2-redhat @aarnphm @NickLucche
|
/tests/entrypoints @DarkLight1337 @robertgshaw2-redhat @aarnphm @NickLucche
|
||||||
/tests/evals @mgoin
|
/tests/evals @mgoin @vadiklyutiy
|
||||||
/tests/kernels @mgoin @tlrmchlsmth @WoosukKwon @yewentao256
|
/tests/kernels @mgoin @tlrmchlsmth @WoosukKwon @yewentao256
|
||||||
/tests/models @DarkLight1337 @ywang96
|
/tests/models @DarkLight1337 @ywang96
|
||||||
/tests/multimodal @DarkLight1337 @ywang96 @NickLucche
|
/tests/multimodal @DarkLight1337 @ywang96 @NickLucche
|
||||||
@@ -82,7 +83,7 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
|
|||||||
/tests/v1/core @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @alexm-redhat @heheda12345 @ApostaC @orozery
|
/tests/v1/core @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @alexm-redhat @heheda12345 @ApostaC @orozery
|
||||||
/tests/weight_loading @mgoin @youkaichao @yewentao256
|
/tests/weight_loading @mgoin @youkaichao @yewentao256
|
||||||
/tests/lora @jeejeelee
|
/tests/lora @jeejeelee
|
||||||
/tests/models/language/generation/test_hybrid.py @tdoublep
|
/tests/models/language/generation/test_hybrid.py @tdoublep @tomeras91
|
||||||
/tests/v1/kv_connector/nixl_integration @NickLucche
|
/tests/v1/kv_connector/nixl_integration @NickLucche
|
||||||
/tests/v1/kv_connector @ApostaC @orozery
|
/tests/v1/kv_connector @ApostaC @orozery
|
||||||
/tests/v1/kv_offload @ApostaC @orozery
|
/tests/v1/kv_offload @ApostaC @orozery
|
||||||
@@ -126,9 +127,14 @@ mkdocs.yaml @hmellor
|
|||||||
/vllm/platforms/xpu.py @jikunshang
|
/vllm/platforms/xpu.py @jikunshang
|
||||||
/docker/Dockerfile.xpu @jikunshang
|
/docker/Dockerfile.xpu @jikunshang
|
||||||
|
|
||||||
|
# Nemotron-specific files
|
||||||
|
/vllm/model_executor/models/*nemotron* @tomeras91
|
||||||
|
/vllm/transformers_utils/configs/*nemotron* @tomeras91
|
||||||
|
/tests/**/*nemotron* @tomeras91
|
||||||
|
|
||||||
# Qwen-specific files
|
# Qwen-specific files
|
||||||
/vllm/attention/backends/dual_chunk_flash_attn.py @sighingnow
|
/vllm/model_executor/models/qwen* @sighingnow @vadiklyutiy
|
||||||
/vllm/model_executor/models/qwen* @sighingnow
|
/vllm/transformers_utils/configs/qwen* @sighingnow @vadiklyutiy
|
||||||
|
|
||||||
# MTP-specific files
|
# MTP-specific files
|
||||||
/vllm/model_executor/models/deepseek_mtp.py @luccafong
|
/vllm/model_executor/models/deepseek_mtp.py @luccafong
|
||||||
@@ -144,7 +150,7 @@ mkdocs.yaml @hmellor
|
|||||||
# Kernels
|
# Kernels
|
||||||
/vllm/v1/attention/ops/chunked_prefill_paged_decode.py @tdoublep
|
/vllm/v1/attention/ops/chunked_prefill_paged_decode.py @tdoublep
|
||||||
/vllm/v1/attention/ops/triton_unified_attention.py @tdoublep
|
/vllm/v1/attention/ops/triton_unified_attention.py @tdoublep
|
||||||
/vllm/model_executor/layers/fla @ZJY0516
|
/vllm/model_executor/layers/fla @ZJY0516 @vadiklyutiy
|
||||||
|
|
||||||
# ROCm related: specify owner with write access to notify AMD folks for careful code review
|
# ROCm related: specify owner with write access to notify AMD folks for careful code review
|
||||||
/vllm/**/*rocm* @tjtanaa
|
/vllm/**/*rocm* @tjtanaa
|
||||||
|
|||||||
420
CMakeLists.txt
420
CMakeLists.txt
@@ -309,7 +309,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
|||||||
SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
|
SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
|
||||||
|
|
||||||
# Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building.
|
# Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building.
|
||||||
set(CUTLASS_REVISION "v4.2.1")
|
set(CUTLASS_REVISION "v4.4.2")
|
||||||
|
|
||||||
# Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
|
# Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
|
||||||
if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR})
|
if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR})
|
||||||
@@ -340,7 +340,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
|||||||
|
|
||||||
list(APPEND VLLM_EXT_SRC
|
list(APPEND VLLM_EXT_SRC
|
||||||
"csrc/quantization/awq/gemm_kernels.cu"
|
"csrc/quantization/awq/gemm_kernels.cu"
|
||||||
"csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu"
|
|
||||||
"csrc/quantization/fp4/nvfp4_quant_entry.cu"
|
"csrc/quantization/fp4/nvfp4_quant_entry.cu"
|
||||||
"csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
|
"csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
|
||||||
"csrc/cutlass_extensions/common.cpp")
|
"csrc/cutlass_extensions/common.cpp")
|
||||||
@@ -490,132 +489,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
|||||||
" in CUDA target architectures")
|
" in CUDA target architectures")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
|
||||||
set(SCALED_MM_3X_ARCHS)
|
|
||||||
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
|
|
||||||
# CUDA 12.0 or later
|
|
||||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}")
|
|
||||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS)
|
|
||||||
set(SRCS
|
|
||||||
"csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm90.cu"
|
|
||||||
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8.cu"
|
|
||||||
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_int8.cu"
|
|
||||||
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_azp_sm90_int8.cu"
|
|
||||||
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8.cu")
|
|
||||||
set_gencode_flags_for_srcs(
|
|
||||||
SRCS "${SRCS}"
|
|
||||||
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
|
||||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
|
||||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM90=1")
|
|
||||||
# Let scaled_mm_c2x know it doesn't need to build these arches
|
|
||||||
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
|
|
||||||
message(STATUS "Building scaled_mm_c3x_sm90 for archs: ${SCALED_MM_ARCHS}")
|
|
||||||
else()
|
|
||||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS)
|
|
||||||
message(STATUS "Not building scaled_mm_c3x_sm90 as CUDA Compiler version is "
|
|
||||||
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
|
|
||||||
"later if you intend on running FP8 quantized models on "
|
|
||||||
"Hopper.")
|
|
||||||
else()
|
|
||||||
message(STATUS "Not building scaled_mm_c3x_sm90 as no compatible archs found "
|
|
||||||
"in CUDA target architectures")
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
|
|
||||||
# The cutlass_scaled_mm kernels for Blackwell SM12x (c3x, i.e. CUTLASS 3.x) require
|
|
||||||
# CUDA 12.8 or later
|
|
||||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
|
||||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0f" "${CUDA_ARCHS}")
|
|
||||||
else()
|
|
||||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0a;12.1a" "${CUDA_ARCHS}")
|
|
||||||
endif()
|
|
||||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
|
||||||
set(SRCS
|
|
||||||
"csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm120.cu"
|
|
||||||
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8.cu"
|
|
||||||
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8.cu"
|
|
||||||
)
|
|
||||||
set_gencode_flags_for_srcs(
|
|
||||||
SRCS "${SRCS}"
|
|
||||||
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
|
||||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
|
||||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM120=1")
|
|
||||||
# Let scaled_mm_c2x know it doesn't need to build these arches
|
|
||||||
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
|
|
||||||
message(STATUS "Building scaled_mm_c3x_sm120 for archs: ${SCALED_MM_ARCHS}")
|
|
||||||
else()
|
|
||||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
|
||||||
message(STATUS "Not building scaled_mm_c3x_sm120 as CUDA Compiler version is "
|
|
||||||
"not >= 12.8, we recommend upgrading to CUDA 12.8 or "
|
|
||||||
"later if you intend on running FP8 quantized models on "
|
|
||||||
"Blackwell.")
|
|
||||||
else()
|
|
||||||
message(STATUS "Not building scaled_mm_c3x_120 as no compatible archs found "
|
|
||||||
"in CUDA target architectures")
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
|
|
||||||
# The cutlass_scaled_mm kernels for Blackwell SM100 (c3x, i.e. CUTLASS 3.x)
|
|
||||||
# require CUDA 12.8 or later
|
|
||||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
|
||||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
|
|
||||||
else()
|
|
||||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
|
|
||||||
endif()
|
|
||||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
|
||||||
set(SRCS
|
|
||||||
"csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm100.cu"
|
|
||||||
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu"
|
|
||||||
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8.cu"
|
|
||||||
)
|
|
||||||
set_gencode_flags_for_srcs(
|
|
||||||
SRCS "${SRCS}"
|
|
||||||
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
|
||||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
|
||||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM100=1")
|
|
||||||
# Let scaled_mm_c2x know it doesn't need to build these arches
|
|
||||||
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
|
|
||||||
message(STATUS "Building scaled_mm_c3x_sm100 for archs: ${SCALED_MM_ARCHS}")
|
|
||||||
else()
|
|
||||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
|
||||||
message(STATUS "Not building scaled_mm_c3x_sm100 as CUDA Compiler version is "
|
|
||||||
"not >= 12.8, we recommend upgrading to CUDA 12.8 or "
|
|
||||||
"later if you intend on running FP8 quantized models on "
|
|
||||||
"Blackwell.")
|
|
||||||
else()
|
|
||||||
message(STATUS "Not building scaled_mm_c3x_100 as no compatible archs found "
|
|
||||||
"in CUDA target architectures")
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
#
|
|
||||||
# For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x)
|
|
||||||
# kernels for the remaining archs that are not already built for 3x.
|
|
||||||
# (Build 8.9 for FP8)
|
|
||||||
cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS
|
|
||||||
"7.5;8.0;8.7;8.9+PTX" "${CUDA_ARCHS}")
|
|
||||||
# subtract out the archs that are already built for 3x
|
|
||||||
list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS})
|
|
||||||
if (SCALED_MM_2X_ARCHS)
|
|
||||||
set(SRCS "csrc/quantization/w8a8/cutlass/scaled_mm_c2x.cu")
|
|
||||||
set_gencode_flags_for_srcs(
|
|
||||||
SRCS "${SRCS}"
|
|
||||||
CUDA_ARCHS "${SCALED_MM_2X_ARCHS}")
|
|
||||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
|
||||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_C2X=1")
|
|
||||||
message(STATUS "Building scaled_mm_c2x for archs: ${SCALED_MM_2X_ARCHS}")
|
|
||||||
else()
|
|
||||||
if (SCALED_MM_3X_ARCHS)
|
|
||||||
message(STATUS "Not building scaled_mm_c2x as all archs are already built"
|
|
||||||
" for and covered by scaled_mm_c3x")
|
|
||||||
else()
|
|
||||||
message(STATUS "Not building scaled_mm_c2x as no compatible archs found "
|
|
||||||
"in CUDA target architectures")
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
# The nvfp4_scaled_mm_sm120 kernels for Blackwell SM12x require
|
# The nvfp4_scaled_mm_sm120 kernels for Blackwell SM12x require
|
||||||
# CUDA 12.8 or later
|
# CUDA 12.8 or later
|
||||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||||
@@ -693,55 +566,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
|||||||
set(MLA_ARCHS)
|
set(MLA_ARCHS)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# CUTLASS MoE kernels
|
|
||||||
|
|
||||||
# The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and ONLY works
|
|
||||||
# on Hopper). get_cutlass_(batched_)moe_mm_data should only be compiled
|
|
||||||
# if it's possible to compile MoE kernels that use its output.
|
|
||||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}")
|
|
||||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
|
|
||||||
set(SRCS "csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm90.cu")
|
|
||||||
set_gencode_flags_for_srcs(
|
|
||||||
SRCS "${SRCS}"
|
|
||||||
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
|
||||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
|
||||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM90=1")
|
|
||||||
message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}")
|
|
||||||
else()
|
|
||||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
|
|
||||||
message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is "
|
|
||||||
"not >= 12.3, we recommend upgrading to CUDA 12.3 or later "
|
|
||||||
"if you intend on running FP8 quantized MoE models on Hopper.")
|
|
||||||
else()
|
|
||||||
message(STATUS "Not building grouped_mm_c3x as no compatible archs found "
|
|
||||||
"in CUDA target architectures.")
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
|
||||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
|
|
||||||
else()
|
|
||||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
|
|
||||||
endif()
|
|
||||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
|
||||||
set(SRCS "csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm100.cu")
|
|
||||||
set_gencode_flags_for_srcs(
|
|
||||||
SRCS "${SRCS}"
|
|
||||||
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
|
||||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
|
||||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
|
|
||||||
message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}")
|
|
||||||
else()
|
|
||||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
|
||||||
message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is "
|
|
||||||
"not >= 12.8, we recommend upgrading to CUDA 12.8 or later "
|
|
||||||
"if you intend on running FP8 quantized MoE models on Blackwell.")
|
|
||||||
else()
|
|
||||||
message(STATUS "Not building grouped_mm_c3x as no compatible archs found "
|
|
||||||
"in CUDA target architectures.")
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
# Expert-specialization MXFP8 blockscaled grouped kernels (SM100+).
|
# Expert-specialization MXFP8 blockscaled grouped kernels (SM100+).
|
||||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||||
cuda_archs_loose_intersection(ES_MXFP8_GROUPED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
|
cuda_archs_loose_intersection(ES_MXFP8_GROUPED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
|
||||||
@@ -787,36 +611,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
|||||||
"in CUDA target architectures.")
|
"in CUDA target architectures.")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# moe_data.cu is used by all CUTLASS MoE kernels.
|
|
||||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
|
||||||
cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
|
|
||||||
else()
|
|
||||||
cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}")
|
|
||||||
endif()
|
|
||||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)
|
|
||||||
set(SRCS "csrc/quantization/w8a8/cutlass/moe/moe_data.cu")
|
|
||||||
set_gencode_flags_for_srcs(
|
|
||||||
SRCS "${SRCS}"
|
|
||||||
CUDA_ARCHS "${CUTLASS_MOE_DATA_ARCHS}")
|
|
||||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
|
||||||
message(STATUS "Building moe_data for archs: ${CUTLASS_MOE_DATA_ARCHS}")
|
|
||||||
else()
|
|
||||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)
|
|
||||||
message(STATUS "Not building moe_data as CUDA Compiler version is "
|
|
||||||
"not >= 12.3, we recommend upgrading to CUDA 12.3 or later "
|
|
||||||
"if you intend on running FP8 quantized MoE models on Hopper or Blackwell.")
|
|
||||||
else()
|
|
||||||
message(STATUS "Not building moe_data as no compatible archs found "
|
|
||||||
"in CUDA target architectures.")
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
|
||||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
|
|
||||||
else()
|
|
||||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
#
|
#
|
||||||
# Machete kernels
|
# Machete kernels
|
||||||
|
|
||||||
@@ -964,7 +758,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
|||||||
# _C_stable_libtorch extension (ops registered via STABLE_TORCH_LIBRARY)
|
# _C_stable_libtorch extension (ops registered via STABLE_TORCH_LIBRARY)
|
||||||
#
|
#
|
||||||
set(VLLM_STABLE_EXT_SRC
|
set(VLLM_STABLE_EXT_SRC
|
||||||
"csrc/libtorch_stable/torch_bindings.cpp")
|
"csrc/libtorch_stable/torch_bindings.cpp"
|
||||||
|
"csrc/cutlass_extensions/common.cpp"
|
||||||
|
"csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu")
|
||||||
|
|
||||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||||
list(APPEND VLLM_STABLE_EXT_SRC
|
list(APPEND VLLM_STABLE_EXT_SRC
|
||||||
@@ -979,6 +775,209 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
|||||||
CUDA_ARCHS "${CUDA_ARCHS}")
|
CUDA_ARCHS "${CUDA_ARCHS}")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
#
|
||||||
|
# CUTLASS scaled_mm kernels (moved from _C to _C_stable_libtorch)
|
||||||
|
#
|
||||||
|
set(SCALED_MM_3X_ARCHS)
|
||||||
|
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
|
||||||
|
# CUDA 12.0 or later
|
||||||
|
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}")
|
||||||
|
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS)
|
||||||
|
set(SRCS
|
||||||
|
"csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c3x_sm90.cu"
|
||||||
|
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8.cu"
|
||||||
|
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_int8.cu"
|
||||||
|
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_azp_sm90_int8.cu"
|
||||||
|
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8.cu")
|
||||||
|
set_gencode_flags_for_srcs(
|
||||||
|
SRCS "${SRCS}"
|
||||||
|
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
||||||
|
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
|
||||||
|
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM90=1")
|
||||||
|
# Let scaled_mm_c2x know it doesn't need to build these arches
|
||||||
|
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
|
||||||
|
message(STATUS "Building scaled_mm_c3x_sm90 for archs: ${SCALED_MM_ARCHS}")
|
||||||
|
else()
|
||||||
|
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS)
|
||||||
|
message(STATUS "Not building scaled_mm_c3x_sm90 as CUDA Compiler version is "
|
||||||
|
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
|
||||||
|
"later if you intend on running FP8 quantized models on "
|
||||||
|
"Hopper.")
|
||||||
|
else()
|
||||||
|
message(STATUS "Not building scaled_mm_c3x_sm90 as no compatible archs found "
|
||||||
|
"in CUDA target architectures")
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
|
||||||
|
|
||||||
|
# The cutlass_scaled_mm kernels for Blackwell SM12x (c3x, i.e. CUTLASS 3.x) require
|
||||||
|
# CUDA 12.8 or later
|
||||||
|
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||||
|
cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0f" "${CUDA_ARCHS}")
|
||||||
|
else()
|
||||||
|
cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0a;12.1a" "${CUDA_ARCHS}")
|
||||||
|
endif()
|
||||||
|
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||||
|
set(SRCS
|
||||||
|
"csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c3x_sm120.cu"
|
||||||
|
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8.cu"
|
||||||
|
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8.cu"
|
||||||
|
)
|
||||||
|
set_gencode_flags_for_srcs(
|
||||||
|
SRCS "${SRCS}"
|
||||||
|
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
||||||
|
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
|
||||||
|
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM120=1")
|
||||||
|
# Let scaled_mm_c2x know it doesn't need to build these arches
|
||||||
|
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
|
||||||
|
message(STATUS "Building scaled_mm_c3x_sm120 for archs: ${SCALED_MM_ARCHS}")
|
||||||
|
else()
|
||||||
|
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||||
|
message(STATUS "Not building scaled_mm_c3x_sm120 as CUDA Compiler version is "
|
||||||
|
"not >= 12.8, we recommend upgrading to CUDA 12.8 or "
|
||||||
|
"later if you intend on running FP8 quantized models on "
|
||||||
|
"Blackwell.")
|
||||||
|
else()
|
||||||
|
message(STATUS "Not building scaled_mm_c3x_120 as no compatible archs found "
|
||||||
|
"in CUDA target architectures")
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
|
||||||
|
|
||||||
|
# The cutlass_scaled_mm kernels for Blackwell SM100 (c3x, i.e. CUTLASS 3.x)
|
||||||
|
# require CUDA 12.8 or later
|
||||||
|
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||||
|
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
|
||||||
|
else()
|
||||||
|
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
|
||||||
|
endif()
|
||||||
|
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||||
|
set(SRCS
|
||||||
|
"csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c3x_sm100.cu"
|
||||||
|
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu"
|
||||||
|
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8.cu"
|
||||||
|
)
|
||||||
|
set_gencode_flags_for_srcs(
|
||||||
|
SRCS "${SRCS}"
|
||||||
|
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
||||||
|
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
|
||||||
|
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM100=1")
|
||||||
|
# Let scaled_mm_c2x know it doesn't need to build these arches
|
||||||
|
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
|
||||||
|
message(STATUS "Building scaled_mm_c3x_sm100 for archs: ${SCALED_MM_ARCHS}")
|
||||||
|
else()
|
||||||
|
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||||
|
message(STATUS "Not building scaled_mm_c3x_sm100 as CUDA Compiler version is "
|
||||||
|
"not >= 12.8, we recommend upgrading to CUDA 12.8 or "
|
||||||
|
"later if you intend on running FP8 quantized models on "
|
||||||
|
"Blackwell.")
|
||||||
|
else()
|
||||||
|
message(STATUS "Not building scaled_mm_c3x_100 as no compatible archs found "
|
||||||
|
"in CUDA target architectures")
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
|
||||||
|
#
|
||||||
|
# For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x)
|
||||||
|
# kernels for the remaining archs that are not already built for 3x.
|
||||||
|
# (Build 8.9 for FP8)
|
||||||
|
cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS
|
||||||
|
"7.5;8.0;8.7;8.9+PTX" "${CUDA_ARCHS}")
|
||||||
|
# subtract out the archs that are already built for 3x
|
||||||
|
list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS})
|
||||||
|
if (SCALED_MM_2X_ARCHS)
|
||||||
|
set(SRCS "csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c2x.cu")
|
||||||
|
set_gencode_flags_for_srcs(
|
||||||
|
SRCS "${SRCS}"
|
||||||
|
CUDA_ARCHS "${SCALED_MM_2X_ARCHS}")
|
||||||
|
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
|
||||||
|
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_C2X=1")
|
||||||
|
message(STATUS "Building scaled_mm_c2x for archs: ${SCALED_MM_2X_ARCHS}")
|
||||||
|
else()
|
||||||
|
if (SCALED_MM_3X_ARCHS)
|
||||||
|
message(STATUS "Not building scaled_mm_c2x as all archs are already built"
|
||||||
|
" for and covered by scaled_mm_c3x")
|
||||||
|
else()
|
||||||
|
message(STATUS "Not building scaled_mm_c2x as no compatible archs found "
|
||||||
|
"in CUDA target architectures")
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
|
||||||
|
#
|
||||||
|
# CUTLASS MoE kernels (moved from _C to _C_stable_libtorch)
|
||||||
|
#
|
||||||
|
|
||||||
|
# The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and ONLY works
|
||||||
|
# on Hopper). get_cutlass_(batched_)moe_mm_data should only be compiled
|
||||||
|
# if it's possible to compile MoE kernels that use its output.
|
||||||
|
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}")
|
||||||
|
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
|
||||||
|
set(SRCS "csrc/libtorch_stable/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm90.cu")
|
||||||
|
set_gencode_flags_for_srcs(
|
||||||
|
SRCS "${SRCS}"
|
||||||
|
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
||||||
|
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
|
||||||
|
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM90=1")
|
||||||
|
message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}")
|
||||||
|
else()
|
||||||
|
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
|
||||||
|
message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is "
|
||||||
|
"not >= 12.3, we recommend upgrading to CUDA 12.3 or later "
|
||||||
|
"if you intend on running FP8 quantized MoE models on Hopper.")
|
||||||
|
else()
|
||||||
|
message(STATUS "Not building grouped_mm_c3x as no compatible archs found "
|
||||||
|
"in CUDA target architectures.")
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||||
|
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
|
||||||
|
else()
|
||||||
|
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
|
||||||
|
endif()
|
||||||
|
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||||
|
set(SRCS "csrc/libtorch_stable/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm100.cu")
|
||||||
|
set_gencode_flags_for_srcs(
|
||||||
|
SRCS "${SRCS}"
|
||||||
|
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
||||||
|
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
|
||||||
|
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
|
||||||
|
message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}")
|
||||||
|
else()
|
||||||
|
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||||
|
message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is "
|
||||||
|
"not >= 12.8, we recommend upgrading to CUDA 12.8 or later "
|
||||||
|
"if you intend on running FP8 quantized MoE models on Blackwell.")
|
||||||
|
else()
|
||||||
|
message(STATUS "Not building grouped_mm_c3x as no compatible archs found "
|
||||||
|
"in CUDA target architectures.")
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# moe_data.cu is used by all CUTLASS MoE kernels.
|
||||||
|
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||||
|
cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
|
||||||
|
else()
|
||||||
|
cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}")
|
||||||
|
endif()
|
||||||
|
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)
|
||||||
|
set(SRCS "csrc/libtorch_stable/quantization/w8a8/cutlass/moe/moe_data.cu")
|
||||||
|
set_gencode_flags_for_srcs(
|
||||||
|
SRCS "${SRCS}"
|
||||||
|
CUDA_ARCHS "${CUTLASS_MOE_DATA_ARCHS}")
|
||||||
|
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
|
||||||
|
message(STATUS "Building moe_data for archs: ${CUTLASS_MOE_DATA_ARCHS}")
|
||||||
|
else()
|
||||||
|
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)
|
||||||
|
message(STATUS "Not building moe_data as CUDA Compiler version is "
|
||||||
|
"not >= 12.3, we recommend upgrading to CUDA 12.3 or later "
|
||||||
|
"if you intend on running FP8 quantized MoE models on Hopper or Blackwell.")
|
||||||
|
else()
|
||||||
|
message(STATUS "Not building moe_data as no compatible archs found "
|
||||||
|
"in CUDA target architectures.")
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
|
||||||
message(STATUS "Enabling C_stable extension.")
|
message(STATUS "Enabling C_stable extension.")
|
||||||
define_extension_target(
|
define_extension_target(
|
||||||
_C_stable_libtorch
|
_C_stable_libtorch
|
||||||
@@ -987,6 +986,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
|||||||
SOURCES ${VLLM_STABLE_EXT_SRC}
|
SOURCES ${VLLM_STABLE_EXT_SRC}
|
||||||
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
|
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
|
||||||
ARCHITECTURES ${VLLM_GPU_ARCHES}
|
ARCHITECTURES ${VLLM_GPU_ARCHES}
|
||||||
|
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
|
||||||
USE_SABI 3
|
USE_SABI 3
|
||||||
WITH_SOABI)
|
WITH_SOABI)
|
||||||
|
|
||||||
@@ -1000,6 +1000,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
|||||||
# Needed to use cuda APIs from C-shim
|
# Needed to use cuda APIs from C-shim
|
||||||
target_compile_definitions(_C_stable_libtorch PRIVATE
|
target_compile_definitions(_C_stable_libtorch PRIVATE
|
||||||
USE_CUDA)
|
USE_CUDA)
|
||||||
|
|
||||||
|
# Needed by CUTLASS kernels
|
||||||
|
target_compile_definitions(_C_stable_libtorch PRIVATE
|
||||||
|
CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
#
|
#
|
||||||
|
|||||||
@@ -373,6 +373,7 @@ if (ENABLE_X86_ISA)
|
|||||||
"csrc/cpu/sgl-kernels/gemm.cpp"
|
"csrc/cpu/sgl-kernels/gemm.cpp"
|
||||||
"csrc/cpu/sgl-kernels/gemm_int8.cpp"
|
"csrc/cpu/sgl-kernels/gemm_int8.cpp"
|
||||||
"csrc/cpu/sgl-kernels/gemm_fp8.cpp"
|
"csrc/cpu/sgl-kernels/gemm_fp8.cpp"
|
||||||
|
"csrc/cpu/sgl-kernels/gemm_int4.cpp"
|
||||||
"csrc/cpu/sgl-kernels/moe.cpp"
|
"csrc/cpu/sgl-kernels/moe.cpp"
|
||||||
"csrc/cpu/sgl-kernels/moe_int8.cpp"
|
"csrc/cpu/sgl-kernels/moe_int8.cpp"
|
||||||
"csrc/cpu/sgl-kernels/moe_fp8.cpp")
|
"csrc/cpu/sgl-kernels/moe_fp8.cpp")
|
||||||
|
|||||||
@@ -117,6 +117,14 @@ inline void parallel_for(int n, const func_t& f) {
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline int get_thread_num() {
|
||||||
|
#if defined(_OPENMP)
|
||||||
|
return omp_get_thread_num();
|
||||||
|
#else
|
||||||
|
return 0;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
// for 1d parallel, use `actual_nth`
|
// for 1d parallel, use `actual_nth`
|
||||||
// for 2d parallel, use even nths, e.g. 43->42
|
// for 2d parallel, use even nths, e.g. 43->42
|
||||||
int inline adjust_num_threads(int m) {
|
int inline adjust_num_threads(int m) {
|
||||||
|
|||||||
@@ -17,8 +17,8 @@ constexpr int block_size_n() { return 2 * TILE_N; }
|
|||||||
template <typename T> inline bool can_use_brgemm(int M);
|
template <typename T> inline bool can_use_brgemm(int M);
|
||||||
template <> inline bool can_use_brgemm<at::BFloat16>(int M) { return M > 4; }
|
template <> inline bool can_use_brgemm<at::BFloat16>(int M) { return M > 4; }
|
||||||
template <> inline bool can_use_brgemm<at::Half>(int M) { return true; }
|
template <> inline bool can_use_brgemm<at::Half>(int M) { return true; }
|
||||||
// TODO: add u8s8 brgemm, this requires PyTorch 2.7
|
template <> inline bool can_use_brgemm<int8_t>(int M) { return M > 4; }
|
||||||
template <> inline bool can_use_brgemm<int8_t>(int M) { return false; }
|
template <> inline bool can_use_brgemm<uint8_t>(int M) { return M > 4; }
|
||||||
template <> inline bool can_use_brgemm<at::Float8_e4m3fn>(int M) { return M > 4; }
|
template <> inline bool can_use_brgemm<at::Float8_e4m3fn>(int M) { return M > 4; }
|
||||||
template <> inline bool can_use_brgemm<at::quint4x2>(int M) { return M > 4; }
|
template <> inline bool can_use_brgemm<at::quint4x2>(int M) { return M > 4; }
|
||||||
|
|
||||||
@@ -40,9 +40,17 @@ inline int64_t get_row_size(int64_t K, bool use_int8_w8a8) {
|
|||||||
return use_int8_w8a8 ? K + sizeof(int32_t) : K;
|
return use_int8_w8a8 ? K + sizeof(int32_t) : K;
|
||||||
}
|
}
|
||||||
|
|
||||||
// pack weight to vnni format
|
inline int64_t get_4bit_block_k_size(int64_t group_size) {
|
||||||
|
return group_size > 128 ? 128 : group_size;
|
||||||
|
}
|
||||||
|
|
||||||
|
// pack weight into vnni format
|
||||||
at::Tensor convert_weight_packed(at::Tensor& weight);
|
at::Tensor convert_weight_packed(at::Tensor& weight);
|
||||||
|
|
||||||
|
// pack weight to vnni format for int4 (adapted from sglang)
|
||||||
|
std::tuple<at::Tensor, at::Tensor, at::Tensor>
|
||||||
|
convert_weight_packed_scale_zp(at::Tensor qweight, at::Tensor qzeros, at::Tensor scales);
|
||||||
|
|
||||||
// moe implementations for int8 w8a8
|
// moe implementations for int8 w8a8
|
||||||
template <typename scalar_t>
|
template <typename scalar_t>
|
||||||
void fused_experts_int8_kernel_impl(
|
void fused_experts_int8_kernel_impl(
|
||||||
@@ -233,6 +241,31 @@ void tinygemm_kernel(
|
|||||||
int64_t strideBs,
|
int64_t strideBs,
|
||||||
bool brg);
|
bool brg);
|
||||||
|
|
||||||
|
// int4 scaled GEMM (adapted from sglang)
|
||||||
|
at::Tensor int4_scaled_mm_cpu(
|
||||||
|
at::Tensor& x, at::Tensor& w, at::Tensor& w_zeros, at::Tensor& w_scales, std::optional<at::Tensor> bias);
|
||||||
|
|
||||||
|
// int4 tinygemm kernel interface(adapted from sglang)
|
||||||
|
template <typename scalar_t>
|
||||||
|
void tinygemm_kernel(
|
||||||
|
scalar_t* C,
|
||||||
|
float* C_temp,
|
||||||
|
const uint8_t* A,
|
||||||
|
const float* scales_a,
|
||||||
|
const int32_t* qzeros_a,
|
||||||
|
const uint8_t* B,
|
||||||
|
const float* scales_b,
|
||||||
|
const int8_t* qzeros_b,
|
||||||
|
const int32_t* compensation,
|
||||||
|
int8_t* dqB_tmp,
|
||||||
|
int64_t M,
|
||||||
|
int64_t K,
|
||||||
|
int64_t lda,
|
||||||
|
int64_t ldc_f,
|
||||||
|
int64_t ldc_s,
|
||||||
|
bool store_out,
|
||||||
|
bool use_brgemm);
|
||||||
|
|
||||||
// TODO: debug print, remove me later
|
// TODO: debug print, remove me later
|
||||||
inline void print_16x32i(const __m512i x) {
|
inline void print_16x32i(const __m512i x) {
|
||||||
int32_t a[16];
|
int32_t a[16];
|
||||||
|
|||||||
755
csrc/cpu/sgl-kernels/gemm_int4.cpp
Normal file
755
csrc/cpu/sgl-kernels/gemm_int4.cpp
Normal file
@@ -0,0 +1,755 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
// Adapted from sgl-project/sglang
|
||||||
|
// https://github.com/sgl-project/sglang/pull/8226
|
||||||
|
|
||||||
|
#include <ATen/ATen.h>
|
||||||
|
|
||||||
|
#include "common.h"
|
||||||
|
#include "gemm.h"
|
||||||
|
#include "vec.h"
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
#define BLOCK_N block_size_n()
|
||||||
|
#define BLOCK_M 128
|
||||||
|
|
||||||
|
template <bool sym_quant_act>
|
||||||
|
struct ActDtype;
|
||||||
|
template <>
|
||||||
|
struct ActDtype<true> {
|
||||||
|
using type = int8_t;
|
||||||
|
};
|
||||||
|
template <>
|
||||||
|
struct ActDtype<false> {
|
||||||
|
using type = uint8_t;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct alignas(32) m256i_wrapper {
|
||||||
|
__m256i data;
|
||||||
|
};
|
||||||
|
|
||||||
|
#if defined(CPU_CAPABILITY_AVX512)
|
||||||
|
inline std::array<m256i_wrapper, 2> load_zps_4vnni(
|
||||||
|
const int8_t* __restrict__ zps) {
|
||||||
|
__m256i vzps_low = _mm256_set1_epi64x(*reinterpret_cast<const int64_t*>(zps));
|
||||||
|
__m256i vzps_high =
|
||||||
|
_mm256_set1_epi64x(*reinterpret_cast<const int64_t*>(zps + 8));
|
||||||
|
__m256i shuffle_mask =
|
||||||
|
_mm256_set_epi8(7, 7, 7, 7, 6, 6, 6, 6, 5, 5, 5, 5, 4, 4, 4, 4, 3, 3, 3,
|
||||||
|
3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0);
|
||||||
|
vzps_low = _mm256_shuffle_epi8(vzps_low, shuffle_mask);
|
||||||
|
vzps_high = _mm256_shuffle_epi8(vzps_high, shuffle_mask);
|
||||||
|
m256i_wrapper vzps_low_wp, vzps_high_wp;
|
||||||
|
vzps_low_wp.data = vzps_low;
|
||||||
|
vzps_high_wp.data = vzps_high;
|
||||||
|
return {vzps_low_wp, vzps_high_wp};
|
||||||
|
}
|
||||||
|
|
||||||
|
inline std::array<m256i_wrapper, 2> load_uint4_as_int8(
|
||||||
|
const uint8_t* __restrict__ qB) {
|
||||||
|
__m256i packed = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(qB));
|
||||||
|
const __m256i low_mask = _mm256_set1_epi8(0x0f);
|
||||||
|
__m256i high = _mm256_srli_epi16(packed, 4);
|
||||||
|
high = _mm256_and_si256(high, low_mask);
|
||||||
|
__m256i low = _mm256_and_si256(packed, low_mask);
|
||||||
|
m256i_wrapper low_wp, high_wp;
|
||||||
|
low_wp.data = low;
|
||||||
|
high_wp.data = high;
|
||||||
|
return {low_wp, high_wp};
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int N, int ldb>
|
||||||
|
void _dequant_weight_zp_only(const uint8_t* __restrict__ B, int8_t* dqB,
|
||||||
|
const int8_t* __restrict__ qzeros, int64_t K) {
|
||||||
|
#pragma GCC unroll 2
|
||||||
|
for (int n = 0; n < N; n += 16) {
|
||||||
|
auto [zps_low_wp, zps_high_wp] = load_zps_4vnni(&qzeros[n]);
|
||||||
|
auto zps_low = zps_low_wp.data;
|
||||||
|
auto zps_high = zps_high_wp.data;
|
||||||
|
for (int k = 0; k < K; k += 4) {
|
||||||
|
auto [vb_low_wp, vb_high_wp] =
|
||||||
|
load_uint4_as_int8(B + ldb * k + n / 2 * 4);
|
||||||
|
auto vb_low = vb_low_wp.data;
|
||||||
|
auto vb_high = vb_high_wp.data;
|
||||||
|
vb_high = _mm256_sub_epi8(vb_high, zps_high);
|
||||||
|
vb_low = _mm256_sub_epi8(vb_low, zps_low);
|
||||||
|
_mm256_storeu_si256(reinterpret_cast<__m256i_u*>(dqB + N * k + n * 4),
|
||||||
|
vb_low);
|
||||||
|
_mm256_storeu_si256(
|
||||||
|
reinterpret_cast<__m256i_u*>(dqB + N * k + (n + 8) * 4), vb_high);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <bool sym_quant_act, int N, bool accum>
|
||||||
|
void _dequant_and_store(float* __restrict__ output,
|
||||||
|
const int32_t* __restrict__ input,
|
||||||
|
const float* __restrict__ scale_a,
|
||||||
|
const int32_t* __restrict__ zp_a,
|
||||||
|
const float* __restrict__ scale_b,
|
||||||
|
const int32_t* __restrict__ comp_b, int M, int ldi,
|
||||||
|
int ldo, int ldsa = 1) {
|
||||||
|
for (int m = 0; m < M; ++m) {
|
||||||
|
float a_scale = *(scale_a + m * ldsa);
|
||||||
|
__m512 va_scale = _mm512_set1_ps(a_scale);
|
||||||
|
int32_t a_zp;
|
||||||
|
__m512i va_zp;
|
||||||
|
if constexpr (!sym_quant_act) {
|
||||||
|
a_zp = *(zp_a + m * ldsa);
|
||||||
|
va_zp = _mm512_set1_epi32(a_zp);
|
||||||
|
}
|
||||||
|
int n = 0;
|
||||||
|
#pragma GCC unroll 2
|
||||||
|
for (; n < N; n += 16) {
|
||||||
|
__m512i vc = _mm512_loadu_si512(input + m * ldi + n);
|
||||||
|
if constexpr (!sym_quant_act) {
|
||||||
|
__m512i vb_comp = _mm512_loadu_si512(comp_b + n);
|
||||||
|
vc = _mm512_sub_epi32(vc, _mm512_mullo_epi32(vb_comp, va_zp));
|
||||||
|
}
|
||||||
|
__m512 vc_f = _mm512_cvtepi32_ps(vc);
|
||||||
|
__m512 vc_f_mul = _mm512_mul_ps(vc_f, va_scale);
|
||||||
|
__m512 vb_s = _mm512_loadu_ps(scale_b + n);
|
||||||
|
vc_f_mul = _mm512_mul_ps(vc_f_mul, vb_s);
|
||||||
|
if constexpr (accum) {
|
||||||
|
__m512 vo = _mm512_loadu_ps(output + m * ldo + n);
|
||||||
|
_mm512_storeu_ps(output + m * ldo + n, _mm512_add_ps(vo, vc_f_mul));
|
||||||
|
} else {
|
||||||
|
_mm512_storeu_ps(output + m * ldo + n, vc_f_mul);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (; n < N; ++n) {
|
||||||
|
float dq_val;
|
||||||
|
if constexpr (sym_quant_act) {
|
||||||
|
dq_val = (float)input[m * ldi + n] * a_scale * scale_b[n];
|
||||||
|
} else {
|
||||||
|
dq_val = (float)(input[m * ldi + n] - a_zp * comp_b[n]) * a_scale *
|
||||||
|
scale_b[n];
|
||||||
|
}
|
||||||
|
if constexpr (accum) {
|
||||||
|
output[m * ldo + n] += dq_val;
|
||||||
|
} else {
|
||||||
|
output[m * ldo + n] = dq_val;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#else
|
||||||
|
template <int N, int ldb>
|
||||||
|
void _dequant_weight_zp_only(const uint8_t* B, int8_t* dqB,
|
||||||
|
const int8_t* qzeros, int64_t K) {
|
||||||
|
for (int k = 0; k < K; ++k) {
|
||||||
|
for (int n = 0; n < N / 2; ++n) {
|
||||||
|
int32_t b = (int32_t)B[k * ldb + n];
|
||||||
|
dqB[k * N + n * 2] = (b & 0xf) - qzeros[n];
|
||||||
|
dqB[k * N + n * 2 + 1] = (b >> 4) - qzeros[n];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if defined(CPU_CAPABILITY_AVX512)
|
||||||
|
inline __m512i combine_m256i(__m256i a, __m256i b) {
|
||||||
|
__m512i c = _mm512_castsi256_si512(a);
|
||||||
|
return _mm512_inserti64x4(c, b, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline __m512i combine_m256i(std::array<m256i_wrapper, 2> two_256) {
|
||||||
|
return combine_m256i(two_256[0].data, two_256[1].data);
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline __m512i _mm512_sign_epi8(__m512i a, __m512i b) {
|
||||||
|
__m512i zero = _mm512_setzero_si512();
|
||||||
|
__mmask64 blt0 = _mm512_movepi8_mask(b);
|
||||||
|
return _mm512_mask_sub_epi8(a, blt0, zero, a);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <bool sym_quant_act, int M, int N, int ldb>
|
||||||
|
void _dequant_gemm_accum_small_M(float* __restrict__ C, const uint8_t* A,
|
||||||
|
const float* scales_a, const int32_t* qzeros_a,
|
||||||
|
const uint8_t* B, const float* scales_b,
|
||||||
|
const int8_t* qzeros_b, int64_t K, int64_t lda,
|
||||||
|
int64_t ldc) {
|
||||||
|
constexpr int COLS = N / 16;
|
||||||
|
__m512i ones = _mm512_set1_epi8(1);
|
||||||
|
__m512i va;
|
||||||
|
__m512i vb[COLS];
|
||||||
|
__m512i vc[M * COLS];
|
||||||
|
__m512 vscales[COLS];
|
||||||
|
__m512i vzps[COLS];
|
||||||
|
__m512i vcompensate[COLS];
|
||||||
|
|
||||||
|
Unroll<COLS>{}([&](auto i) {
|
||||||
|
vscales[i] = _mm512_loadu_ps(scales_b + i * 16);
|
||||||
|
vzps[i] = combine_m256i(load_zps_4vnni(qzeros_b + i * 16));
|
||||||
|
if constexpr (!sym_quant_act) {
|
||||||
|
vcompensate[i] = _mm512_setzero_epi32();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
Unroll<M * COLS>{}([&](auto i) { vc[i] = _mm512_setzero_epi32(); });
|
||||||
|
|
||||||
|
auto compute = [&](auto i, int k) {
|
||||||
|
constexpr const int row = i / COLS;
|
||||||
|
constexpr const int col = i % COLS;
|
||||||
|
|
||||||
|
if constexpr (col == 0) {
|
||||||
|
va = _mm512_set1_epi32(*(int32_t*)(A + row * lda + k));
|
||||||
|
}
|
||||||
|
|
||||||
|
if constexpr (row == 0) {
|
||||||
|
int B_offset = k * ldb + col * 16 * 2;
|
||||||
|
vb[col] = combine_m256i(load_uint4_as_int8(B + B_offset));
|
||||||
|
vb[col] = _mm512_sub_epi8(vb[col], vzps[col]);
|
||||||
|
if constexpr (!sym_quant_act) {
|
||||||
|
vcompensate[col] = _mm512_dpbusd_epi32(vcompensate[col], ones, vb[col]);
|
||||||
|
}
|
||||||
|
_mm_prefetch(B + B_offset + 128 * ldb, _MM_HINT_T0);
|
||||||
|
}
|
||||||
|
if constexpr (sym_quant_act) {
|
||||||
|
auto vsb = _mm512_sign_epi8(vb[col], va);
|
||||||
|
auto vabsa = _mm512_sign_epi8(va, va);
|
||||||
|
vc[i] = _mm512_dpbusds_epi32(vc[i], vabsa, vsb);
|
||||||
|
} else {
|
||||||
|
vc[i] = _mm512_dpbusd_epi32(vc[i], va, vb[col]);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
constexpr const int unroll = 4;
|
||||||
|
int k = 0;
|
||||||
|
for (; k < K / 4 / unroll; k++) {
|
||||||
|
Unroll<unroll>{}(
|
||||||
|
[&](auto i) { Unroll<M * COLS>{}(compute, 4 * (k * unroll + i)); });
|
||||||
|
}
|
||||||
|
k *= 4 * unroll;
|
||||||
|
for (; k < K; k += 4) {
|
||||||
|
Unroll<M * COLS>{}(compute, k);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto store = [&](auto i) {
|
||||||
|
constexpr const int row = i / COLS;
|
||||||
|
constexpr const int col = i % COLS;
|
||||||
|
__m512 vc_float;
|
||||||
|
if constexpr (!sym_quant_act) {
|
||||||
|
vc[i] = _mm512_sub_epi32(
|
||||||
|
vc[i], _mm512_mullo_epi32(vcompensate[col],
|
||||||
|
_mm512_set1_epi32(*(qzeros_a + row))));
|
||||||
|
}
|
||||||
|
vc_float = _mm512_cvtepi32_ps(vc[i]);
|
||||||
|
vc_float = _mm512_mul_ps(vc_float, _mm512_set1_ps(*(scales_a + row)));
|
||||||
|
|
||||||
|
vc_float = _mm512_mul_ps(vc_float, vscales[col]);
|
||||||
|
auto vc_old = _mm512_loadu_ps(C + row * ldc + col * 16);
|
||||||
|
vc_float = _mm512_add_ps(vc_float, vc_old);
|
||||||
|
_mm512_storeu_ps(C + row * ldc + col * 16, vc_float);
|
||||||
|
};
|
||||||
|
Unroll<M * COLS>{}(store);
|
||||||
|
}
|
||||||
|
|
||||||
|
#define CALL_DEQUANT_GEMM_ACCUM_SMALL_M(M) \
|
||||||
|
_dequant_gemm_accum_small_M<sym_quant_act, M, N, ldb>( \
|
||||||
|
C, A, scales_a, qzeros_a, B, scales_b, qzeros_b, K, lda, ldc);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
template <bool sym_quant_act, int N, int ldb>
|
||||||
|
void _dequant_gemm_accum(float* C, const uint8_t* A, const float* scales_a,
|
||||||
|
const int32_t* qzeros_a, const uint8_t* B,
|
||||||
|
const float* scales_b, const int8_t* qzeros_b,
|
||||||
|
const int32_t* compensation, int8_t* dqB, int64_t M,
|
||||||
|
int64_t K, int64_t lda, int64_t ldc, bool use_brgemm) {
|
||||||
|
#if defined(CPU_CAPABILITY_AVX512)
|
||||||
|
if (!use_brgemm) {
|
||||||
|
switch (M) {
|
||||||
|
case 1:
|
||||||
|
CALL_DEQUANT_GEMM_ACCUM_SMALL_M(1);
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
CALL_DEQUANT_GEMM_ACCUM_SMALL_M(2);
|
||||||
|
break;
|
||||||
|
case 3:
|
||||||
|
CALL_DEQUANT_GEMM_ACCUM_SMALL_M(3);
|
||||||
|
break;
|
||||||
|
case 4:
|
||||||
|
CALL_DEQUANT_GEMM_ACCUM_SMALL_M(4);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
TORCH_CHECK(false, "tinygemm_kernel: unexpected M for AVX path!");
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
_dequant_weight_zp_only<N, ldb>(B, dqB, qzeros_b, K);
|
||||||
|
using Tin = typename ActDtype<sym_quant_act>::type;
|
||||||
|
Tin* A_ptr = (Tin*)A;
|
||||||
|
if (use_brgemm) {
|
||||||
|
int32_t C_i32[M * N];
|
||||||
|
at::native::cpublas::brgemm(M, N, K, lda, N /*ldb*/, N /*ldc*/,
|
||||||
|
false /* add_C */, A_ptr, dqB, C_i32,
|
||||||
|
true /* is_vnni */);
|
||||||
|
_mm_prefetch(B + N * K / 2, _MM_HINT_T0);
|
||||||
|
_mm_prefetch(A + K, _MM_HINT_T0);
|
||||||
|
_dequant_and_store<sym_quant_act, N, true>(C, C_i32, scales_a, qzeros_a,
|
||||||
|
scales_b, compensation, M,
|
||||||
|
N /*ldi*/, ldc, 1 /*ldsa*/);
|
||||||
|
} else
|
||||||
|
#endif
|
||||||
|
{
|
||||||
|
TORCH_CHECK(false, "tinygemm_kernel: scalar path not implemented!");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int N>
|
||||||
|
inline void copy_bias(const float* bias_ptr, float* y_buf, int64_t m) {
|
||||||
|
if (bias_ptr) {
|
||||||
|
for (int i = 0; i < m; ++i) {
|
||||||
|
int j = 0;
|
||||||
|
#if defined(CPU_CAPABILITY_AVX512)
|
||||||
|
#pragma GCC unroll 2
|
||||||
|
for (; j < N; j += 16) {
|
||||||
|
__m512 bias_vec = _mm512_loadu_ps(bias_ptr + j);
|
||||||
|
_mm512_storeu_ps(y_buf + i * N + j, bias_vec);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
for (; j < N; ++j) {
|
||||||
|
y_buf[i * N + j] = bias_ptr[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int i = 0; i < m; ++i) {
|
||||||
|
int j = 0;
|
||||||
|
#if defined(CPU_CAPABILITY_AVX512)
|
||||||
|
#pragma GCC unroll 2
|
||||||
|
for (; j < N; j += 16) {
|
||||||
|
__m512 zero_vec = _mm512_setzero_ps();
|
||||||
|
_mm512_storeu_ps(y_buf + i * N + j, zero_vec);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
for (; j < N; ++j) {
|
||||||
|
y_buf[i * N + j] = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy an m x N tile from the fp32 accumulation buffer `y_buf` (row stride N)
// into the output matrix `c_ptr` (row stride `lda`), converting each value to
// the requested output dtype (fp32, bf16, or fp16).
// NOTE(review): the AVX-512 loops advance j by 16 while only testing `j < N`,
// so they assume the compile-time tile width N is a multiple of 16; the scalar
// tail loops only execute in non-AVX-512 builds. TODO confirm BLOCK_N is
// always a multiple of 16.
template <int N, typename out_dtype>
inline void store_out(const float* y_buf, out_dtype* c_ptr, int64_t m,
                      int64_t lda) {
  for (int i = 0; i < m; ++i) {
    int j = 0;
    if constexpr (std::is_same<out_dtype, float>::value) {
      // fp32 output: straight 16-lane copy.
#if defined(CPU_CAPABILITY_AVX512)
#pragma GCC unroll 2
      for (; j < N; j += 16) {
        __m512 y_vec = _mm512_loadu_ps(y_buf + i * N + j);
        _mm512_storeu_ps(c_ptr + i * lda + j, y_vec);
      }
#endif
      for (; j < N; ++j) {
        c_ptr[i * lda + j] = y_buf[i * N + j];
      }
    } else if constexpr (std::is_same<out_dtype, at::BFloat16>::value) {
      // bf16 output: narrow 16 fp32 lanes into one 256-bit bf16 vector.
#if defined(CPU_CAPABILITY_AVX512)
#pragma GCC unroll 2
      for (; j < N; j += 16) {
        __m512 y_vec = _mm512_loadu_ps(y_buf + i * N + j);
        __m256i y_bf16_vec = at::vec::cvtfp32_bf16(y_vec);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(c_ptr + i * lda + j),
                            y_bf16_vec);
      }
#endif
      for (; j < N; ++j) {
        c_ptr[i * lda + j] = at::BFloat16(y_buf[i * N + j]);
      }
    } else if constexpr (std::is_same<out_dtype, at::Half>::value) {
      // fp16 output: narrow 16 fp32 lanes into one 256-bit fp16 vector.
#if defined(CPU_CAPABILITY_AVX512)
#pragma GCC unroll 2
      for (; j < N; j += 16) {
        __m512 y_vec = _mm512_loadu_ps(y_buf + i * N + j);
        __m256i y_fp16_vec = at::vec::cvtfp32_fp16(y_vec);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(c_ptr + i * lda + j),
                            y_fp16_vec);
      }
#endif
      for (; j < N; ++j) {
        c_ptr[i * lda + j] = at::Half(y_buf[i * N + j]);
      }
    } else {
      TORCH_CHECK(false, "Unsupported output dtype");
    }
  }
}
|
||||||
|
|
||||||
|
void fill_val_stub(int32_t* __restrict__ output, int32_t value, int64_t size) {
|
||||||
|
using iVec = at::vec::Vectorized<int32_t>;
|
||||||
|
constexpr int VecSize = iVec::size();
|
||||||
|
const iVec fill_val_vec = iVec(value);
|
||||||
|
int64_t d;
|
||||||
|
#pragma GCC unroll 4
|
||||||
|
for (d = 0; d <= size - VecSize; d += VecSize) {
|
||||||
|
fill_val_vec.store(output + d);
|
||||||
|
}
|
||||||
|
for (; d < size; ++d) {
|
||||||
|
output[d] = value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Core DA8W4 linear kernel: 8-bit quantized activations times packed 4-bit
// weights with per-group scales/zero-points, accumulated in fp32 and written
// out as `out_dtype`.
//
// Layout notes (as established by the code below):
//  - The output is tiled into (block_m x BLOCK_N) tiles; tiles are
//    distributed over threads with at::parallel_for. Parallelism spans both
//    M and N only when M > 128; otherwise threads split the N tiles and each
//    walks every M tile.
//  - Each weight block of BLOCK_N * (_block_k/2 + sizeof(int32_t)) bytes holds
//    the packed int4 data followed by BLOCK_N int32 compensation values; the
//    compensation pointer is only used in the asymmetric-activation case
//    (!sym_quant_act).
//  - output_temp / dequant_weight_temp provide per-thread scratch buffers,
//    indexed by thread id.
template <bool sym_quant_act, typename act_dtype, typename out_dtype>
void _da8w4_linear_impl(
    act_dtype* __restrict__ input, const float* __restrict__ input_scales,
    const int32_t* __restrict__ input_qzeros,
    const uint8_t* __restrict__ weight, const float* __restrict__ weight_scales,
    const int8_t* __restrict__ weight_qzeros, const float* __restrict__ bias,
    out_dtype* __restrict__ output, float* __restrict__ output_temp,
    int8_t* __restrict__ dequant_weight_temp, int64_t M, int64_t N, int64_t K,
    int64_t num_groups) {
  const bool use_brgemm = can_use_brgemm<act_dtype>(M);
  // Pick the M tile size from the problem height.
  int64_t block_m = [&]() -> long {
    if (M <= 48) {
      return M;
    } else if (M < 64) {
      return 32;
    } else if (M < 96) {
      return 64;
    } else {
      return 128;
    }
  }();
  int64_t Mc = div_up(M, block_m);
  bool parallel_on_M = M > 128;
  int64_t Nc = N / BLOCK_N;
  int64_t num_blocks = parallel_on_M ? Mc * Nc : Nc;
  int64_t group_size = div_up(K, num_groups);
  int64_t _block_k = get_4bit_block_k_size(group_size);
  int64_t Kc = K / _block_k;
  // How many K tiles share one quantization group (scales/qzeros row).
  int64_t block_per_group = group_size / _block_k;

  at::parallel_for(0, num_blocks, 1, [&](int64_t begin, int64_t end) {
    int tid = get_thread_num();
    // Per-thread fp32 accumulator and dequantized-weight scratch.
    float* C_tmp = output_temp + tid * block_m * BLOCK_N;
    int8_t* dqB_tmp = dequant_weight_temp + tid * _block_k * BLOCK_N;
    for (const auto i : c10::irange(begin, end)) {
      int64_t mc = parallel_on_M ? i / Nc : 0;
      int64_t nc = parallel_on_M ? i % Nc : i;
      int64_t mc_end = parallel_on_M ? mc + 1 : Mc;

      for (int mci = mc; mci < mc_end; ++mci) {
        // Last M tile may be partial.
        int64_t m_size =
            mci * block_m + block_m > M ? M - mci * block_m : block_m;
        auto bias_data = bias ? bias + nc * BLOCK_N : nullptr;
        // Seed the accumulator with bias (or zeros when bias is null).
        copy_bias<BLOCK_N>(bias_data, C_tmp, m_size);
        for (int kci = 0; kci < Kc; ++kci) {
          // Compensation row lives after the packed int4 payload inside the
          // weight block; only needed for asymmetric activation quantization.
          int32_t* compensation_ptr =
              sym_quant_act
                  ? nullptr
                  : (int32_t*)(void*)(weight +
                                      (nc * Kc + kci) *
                                          (BLOCK_N *
                                           (_block_k / 2 + sizeof(int32_t))) +
                                      _block_k * BLOCK_N / 2);
          _dequant_gemm_accum<sym_quant_act, BLOCK_N, BLOCK_N / 2>(
              /*C*/ C_tmp,
              /*A*/ (uint8_t*)input + mci * block_m * K + kci * _block_k,
              /*scales_a*/ input_scales + mci * block_m,
              /*qzeros_a*/ input_qzeros + mci * block_m,
              /*B*/ weight + (nc * Kc + kci) *
                                 (BLOCK_N * (_block_k / 2 + sizeof(int32_t))),
              /*scales_b*/ weight_scales + nc * BLOCK_N * num_groups +
                  kci / block_per_group * BLOCK_N,
              /*qzeros_b*/ weight_qzeros + nc * BLOCK_N * num_groups +
                  kci / block_per_group * BLOCK_N,
              /*Bcomp*/ compensation_ptr,
              /*dqB_tmp*/ dqB_tmp,
              /*M*/ m_size,
              /*K*/ _block_k,
              /*lda*/ K,
              /*ldc*/ BLOCK_N,
              /*use_brgemm*/ use_brgemm);
        }
        // Convert the fp32 tile to out_dtype and place it in the output.
        store_out<BLOCK_N>(C_tmp, output + mci * block_m * N + nc * BLOCK_N,
                           m_size, N /*lda*/);
      }
    }
    if (use_brgemm) {
      // Release brgemm thread-local state before the worker exits.
      at::native::cpublas::brgemm_release();
    }
  });
}
|
||||||
|
|
||||||
|
} // anonymous namespace
|
||||||
|
|
||||||
|
// Repack a 2-D uint4 weight (one nibble per byte) into the blocked DA8W4
// layout: per (Nc, Kc) block, a VNNI-interleaved packed-int4 payload of
// _block_k*block_n/2 bytes followed by block_n int32 compensation values,
// where compensation[n] = sum_k (weight[n][k] - qzero[n][k]) over the block.
// Scales and qzeros are returned re-blocked as (Nc, G, block_n) in
// fp32 / int8 respectively.
//
// NOTE(review): `new_scales.unsqueeze_(1)` / `new_qzeros.unsqueeze_(1)` are
// in-place metadata ops on tensors aliasing the caller's arguments; the later
// read `scales.size(1)` appears to rely on that mutation when the input was
// 1-D. Confirm callers tolerate their tensors being reshaped.
std::tuple<at::Tensor, at::Tensor, at::Tensor>
convert_int4_weight_packed_with_compensation(const at::Tensor& weight,
                                             const at::Tensor& scales,
                                             const at::Tensor& qzeros) {
  TORCH_CHECK(weight.dim() == 2,
              "DA8W4 CPU: Weight should be a 2D tensor for packing");
  TORCH_CHECK(
      weight.size(1) % 2 == 0,
      "DA8W4 CPU: Weight should have even number of columns for packing");

  // Normalize scales/qzeros to 2-D and to the dtypes the kernel consumes.
  auto new_scales = scales;
  auto new_qzeros = qzeros;
  if (new_scales.dim() == 1) {
    new_scales.unsqueeze_(1);
  }
  new_scales = new_scales.to(at::kFloat);
  if (new_qzeros.dim() == 1) {
    new_qzeros.unsqueeze_(1);
  }
  new_qzeros = new_qzeros.to(at::kChar);
  int64_t N = weight.size(0);
  int64_t K = weight.size(1);
  int64_t G = scales.size(1);  // number of quantization groups along K
  int64_t group_size = K / G;
  int64_t _block_k = get_4bit_block_k_size(group_size);
  constexpr int block_n = block_size_n();
  int64_t Nc = N / block_n;
  int64_t Kc = K / _block_k;

  // Reorder weight to (Nc, Kc, _block_k, block_n) so each block is contiguous
  // with K outer / N inner, matching the packing loop below.
  auto weight_view = weight.view({Nc, block_n, Kc, _block_k});
  at::Tensor weight_reordered = weight_view.permute({0, 2, 3, 1}).contiguous();
  at::Tensor blocked_weight;
  at::Tensor blocked_scales =
      new_scales.view({Nc, block_n, G}).permute({0, 2, 1}).contiguous();
  at::Tensor blocked_qzeros =
      new_qzeros.view({Nc, block_n, G}).permute({0, 2, 1}).contiguous();
  // Compensation: per-block sum of (w - qzero), used by the asymmetric
  // activation path to fold out the activation zero-point.
  auto weight_sub_qzero = weight.view({Nc, block_n, G, -1}).to(at::kInt) -
                          new_qzeros.view({Nc, block_n, G, -1});
  weight_sub_qzero = weight_sub_qzero.view({Nc, block_n, Kc, _block_k});
  at::Tensor compensation = weight_sub_qzero.sum(-1);
  compensation = compensation.permute({0, 2, 1}).contiguous().to(at::kInt);
  // Per block: packed int4 payload + block_n int32 compensation values.
  int64_t buffer_size_nbytes =
      _block_k * block_n / 2 + block_n * sizeof(int32_t);
  blocked_weight = at::empty({Nc, Kc, buffer_size_nbytes}, weight.options());

  auto weight_ptr = weight_reordered.data_ptr<uint8_t>();
  auto compensation_ptr = compensation.data_ptr<int32_t>();
  auto blocked_weight_ptr = blocked_weight.data_ptr<uint8_t>();
  int64_t num_blocks = Nc * Kc;
  at::parallel_for(0, num_blocks, 1, [&](int64_t begin, int64_t end) {
    for (const auto i : c10::irange(begin, end)) {
      auto in_ptr = weight_ptr + i * _block_k * block_n;
      auto out_ptr =
          blocked_weight_ptr + i * block_n * (_block_k / 2 + sizeof(int32_t));
      int32_t* comp_in_prt = compensation_ptr + i * block_n;
      int32_t* comp_out_prt =
          (int32_t*)(void*)(blocked_weight_ptr +
                            i * block_n * (_block_k / 2 + sizeof(int32_t)) +
                            _block_k * block_n / 2);
      // Pack two nibbles per byte in a VNNI-friendly interleave: values from
      // n-groups nb and nb+1 land in the low/high nibble of the same byte.
      constexpr int n_group_size = 8;
      constexpr int vnni_size = 4;
      constexpr int n_group = block_n / n_group_size;
      for (int nb = 0; nb < n_group; nb += 2) {
        for (int k = 0; k < _block_k; k += vnni_size) {
          for (int ni = 0; ni < n_group_size; ++ni) {
            for (int ki = 0; ki < vnni_size; ++ki) {
              int src_idx_1 = nb * n_group_size + ni + (k + ki) * block_n;
              int src_idx_2 = (nb + 1) * n_group_size + ni + (k + ki) * block_n;
              int dst_idx = (nb / 2 * n_group_size + ni) * vnni_size +
                            k * block_n / 2 + ki;
              uint8_t src_1 = *(in_ptr + src_idx_1);
              uint8_t src_2 = *(in_ptr + src_idx_2);
              uint8_t dst = (src_1 & 0x0f) | ((src_2 & 0x0f) << 4);
              *(out_ptr + dst_idx) = dst;
            }
          }
        }
      }
      // Append this block's compensation row after the packed payload.
      for (int nb = 0; nb < block_n; nb++) {
        *(comp_out_prt + nb) = *(comp_in_prt + nb);
      }
    }
  });

  return std::make_tuple(std::move(blocked_weight), std::move(blocked_scales),
                         std::move(blocked_qzeros));
}
|
||||||
|
|
||||||
|
// Unpack AutoAWQ-packed int4 tensors into one nibble per byte.
// AWQ stores eight 4-bit values per int32 in the interleaved lane order
// {0, 4, 1, 5, 2, 6, 3, 7}; shifting by 4*lane and masking with 0xF
// restores the logical order. The unpacked weight has its last two dims
// transposed; the unpacked qzeros keep their layout.
std::tuple<at::Tensor, at::Tensor> autoawq_to_int4pack(at::Tensor qweight,
                                                       at::Tensor qzeros) {
  const auto shifts = at::tensor({0, 4, 1, 5, 2, 6, 3, 7}, at::kInt) * 4;

  auto w_nibbles = at::bitwise_right_shift(qweight.unsqueeze(-1), shifts) & 0xF;
  auto w_unpacked = w_nibbles.flatten(-2).transpose(-1, -2).to(at::kByte);

  auto z_nibbles = at::bitwise_right_shift(qzeros.unsqueeze(-1), shifts) & 0xF;
  auto z_unpacked = z_nibbles.flatten(-2).to(at::kByte);

  return std::make_tuple(w_unpacked, z_unpacked);
}
|
||||||
|
|
||||||
|
// Convert AWQ-format (qweight, qzeros, scales) into the blocked DA8W4 layout
// consumed by the CPU kernels. Handles both a single 2-D weight and a 3-D
// stack of per-expert weights.
// Returns (packed_weight, packed_qzeros, packed_scales) — note the
// qzeros/scales order in the result is swapped relative to the parameters.
std::tuple<at::Tensor, at::Tensor, at::Tensor> convert_weight_packed_scale_zp(
    at::Tensor qweight, at::Tensor qzeros, at::Tensor scales) {
  // Expand the packed nibbles to one value per byte first.
  auto res = autoawq_to_int4pack(qweight, qzeros);
  auto _qweight = std::get<0>(res);
  auto _qzeros = std::get<1>(res);
  auto _scales = scales;
  // The packer expects qzeros/scales with the N dimension leading.
  _qzeros = _qzeros.transpose(-2, -1).contiguous();
  _scales = _scales.transpose(-2, -1).contiguous();
  if (_qweight.dim() == 3) {
    // Expert-stacked case: pack each expert independently into
    // preallocated (E, ...) buffers.
    int64_t E = _qweight.size(0);
    int64_t K = _qweight.size(2);
    int64_t G = _scales.size(2);  // quantization groups along K
    int64_t group_size = K / G;
    int64_t _block_k = get_4bit_block_k_size(group_size);
    int64_t block_n = block_size_n();
    int64_t Nc = _qweight.size(1) / block_n;
    int64_t Kc = K / _block_k;
    // Per (Nc, Kc) block: packed int4 payload plus block_n int32
    // compensation values (matches the single-weight packer's layout).
    int64_t buffer_size_nbytes =
        _block_k * block_n / 2 + block_n * sizeof(int32_t);
    auto blocked_weight =
        at::empty({E, Nc, Kc, buffer_size_nbytes}, _qweight.options());
    auto blocked_scales =
        at::empty({E, Nc, G, block_n}, _scales.options()).to(at::kFloat);
    auto blocked_qzeros =
        at::empty({E, Nc, G, block_n}, _qzeros.options()).to(at::kChar);
    for (int i = 0; i < _qweight.size(0); i++) {
      auto res_ = convert_int4_weight_packed_with_compensation(
          _qweight[i], _scales[i], _qzeros[i]);
      blocked_weight[i] = std::get<0>(res_);
      blocked_scales[i] = std::get<1>(res_);
      blocked_qzeros[i] = std::get<2>(res_);
    }
    _qweight = blocked_weight;
    _scales = blocked_scales;
    _qzeros = blocked_qzeros;
  } else {
    // Single 2-D weight.
    auto res_ = convert_int4_weight_packed_with_compensation(_qweight, _scales,
                                                             _qzeros);
    _qweight = std::get<0>(res_);
    _scales = std::get<1>(res_);
    _qzeros = std::get<2>(res_);
  }

  return std::make_tuple(_qweight, _qzeros, _scales);
}
|
||||||
|
|
||||||
|
// Fused DA8W4 linear with dynamic activation quantization: quantizes each row
// of the bf16/fp16 input to uint8 (asymmetric, fixed zero-point 128), then
// runs the packed-int4 weight GEMM and writes the result in `output_dtype`.
// NOTE(review): the dtype dispatch below keys on `output_dtype` but also
// casts `input`'s data pointer to scalar_t — this requires the input and
// output dtypes to match (true for the int4_scaled_mm_cpu wrapper, which
// passes x.scalar_type()); confirm for any other caller.
at::Tensor int4_scaled_mm_cpu_with_quant(const at::Tensor& input,
                                         const at::Tensor& weight,
                                         const at::Tensor& weight_scales,
                                         const at::Tensor& weight_qzeros,
                                         const std::optional<at::Tensor>& bias,
                                         at::ScalarType output_dtype) {
  RECORD_FUNCTION("vllm::int4_scaled_mm_cpu_with_quant",
                  std::vector<c10::IValue>({input, weight}));

  int64_t M_a = input.size(0);
  int64_t K_a = input.size(1);
  int64_t lda = input.stride(0);

  const auto st = input.scalar_type();
  TORCH_CHECK(
      st == at::kBFloat16 || st == at::kHalf,
      "int4_scaled_mm_cpu_with_quant: expect A to be bfloat16 or half.");

  // Asymmetric activation quantization: uint8 values + per-row fp32 scale
  // + per-row int32 zero-point.
  constexpr bool sym_quant_act = false;
  using Tin = typename ActDtype<sym_quant_act>::type;
  // One flat byte buffer laid out as:
  //   [quantized acts: M_a*K_a u8][row scales: M_a f32][row zps: M_a i32]
  int64_t act_buffer_size =
      M_a * K_a + M_a * sizeof(float) + M_a * sizeof(int32_t);
  auto act_buffer =
      at::empty({act_buffer_size}, input.options().dtype(at::kByte));
  auto Aq_data = act_buffer.data_ptr<uint8_t>();
  auto As_data = reinterpret_cast<float*>(Aq_data + M_a * K_a);
  auto Azp_data = reinterpret_cast<int32_t*>(As_data + M_a);
  // Every row uses the constant zero-point 128 (uint8 midpoint).
  fill_val_stub(Azp_data, 128, M_a);

  auto out_sizes = input.sizes().vec();
  // weight_scales is blocked as (Nc, num_groups, BLOCK_N) => N = Nc * BLOCK_N.
  int64_t N = weight_scales.size(0) * weight_scales.size(-1);
  out_sizes.back() = N;
  auto output = at::empty(out_sizes, input.options());
  int64_t Nc = weight.size(0);
  int64_t Kc = weight.size(1);
  int64_t _block_k = K_a / Kc;
  TORCH_CHECK(N == Nc * BLOCK_N, "DA8W4: weight and input shapes mismatch");
  int64_t num_groups = weight_scales.size(1);

  const uint8_t* b_ptr = weight.data_ptr<uint8_t>();
  const float* b_scales_ptr = weight_scales.data_ptr<float>();
  const int8_t* b_qzeros_ptr = weight_qzeros.data_ptr<int8_t>();
  const float* bias_ptr =
      bias.has_value() ? bias.value().data_ptr<float>() : nullptr;
  // Per-thread scratch: a BLOCK_M x BLOCK_N fp32 accumulator followed by a
  // _block_k x BLOCK_N int8 buffer for dequantized weight tiles.
  int num_threads = at::get_num_threads();
  int64_t temp_buffer_size = num_threads * BLOCK_M * BLOCK_N * sizeof(float) +
                             num_threads * _block_k * BLOCK_N;
  auto c_temp_buffer =
      at::empty({temp_buffer_size}, input.options().dtype(at::kChar));
  float* c_temp_ptr = (float*)((void*)(c_temp_buffer.data_ptr<int8_t>()));
  int8_t* dqB_temp_ptr =
      (int8_t*)((void*)(c_temp_ptr + num_threads * BLOCK_M * BLOCK_N));

  // Quantize the activations row-by-row in parallel, then run the blocked
  // GEMM. The macro exists so the dispatch body can reference scalar_t.
#define LAUNCH_DA8W4_LINEAR_WITH_QUANT_IMPL(sym_quant_act)                 \
  AT_DISPATCH_FLOATING_TYPES_AND2(                                         \
      at::ScalarType::BFloat16, at::ScalarType::Half, output_dtype,        \
      "int4_scaled_mm_cpu", [&] {                                          \
        const scalar_t* __restrict__ A_data = input.data_ptr<scalar_t>();  \
        scalar_t* __restrict__ c_ptr = output.data_ptr<scalar_t>();        \
        at::parallel_for(0, M_a, 0, [&](int64_t begin, int64_t end) {      \
          for (int64_t m = begin; m < end; ++m) {                          \
            quantize_row_int8<scalar_t>(Aq_data + m * K_a, As_data[m],     \
                                        A_data + m * lda, K_a);            \
          }                                                                \
        });                                                                \
        _da8w4_linear_impl<sym_quant_act, Tin, scalar_t>(                  \
            Aq_data, As_data, Azp_data, b_ptr, b_scales_ptr, b_qzeros_ptr, \
            bias_ptr, c_ptr, c_temp_ptr, dqB_temp_ptr, M_a, N, K_a,        \
            num_groups);                                                   \
      });

  LAUNCH_DA8W4_LINEAR_WITH_QUANT_IMPL(sym_quant_act);

  return output;
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Convert `size` fp32 values from `input` to scalar_t and store them to
// `out`. Each iteration loads two fp32 vectors and packs them into a single
// scalar_t vector, so the loop assumes `size` is a multiple of Vec::size()
// — there is no scalar tail. NOTE(review): callers pass BLOCK_N; confirm it
// is always a multiple of the vector width.
template <typename scalar_t>
inline void copy_stub(scalar_t* __restrict__ out,
                      const float* __restrict__ input, int64_t size) {
  using Vec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;
#pragma GCC unroll 4
  for (int64_t d = 0; d < size; d += Vec::size()) {
    fVec x0 = fVec::loadu(input + d);
    fVec x1 = fVec::loadu(input + d + fVec::size());
    Vec res = convert_from_float_ext<scalar_t>(x0, x1);
    res.store(out + d);
  }
}
|
||||||
|
|
||||||
|
} // anonymous namespace
|
||||||
|
|
||||||
|
// One (M x BLOCK_N) tile of the W4A8 GEMM: dequantize-and-accumulate into the
// fp32 scratch `C_temp` (row stride ldc_f), then, when `store_out` is set,
// narrow each row into the final output `C` (dtype scalar_t, row stride
// ldc_s). The first template argument of _dequant_gemm_accum is hard-coded
// to `false`, i.e. the asymmetric-activation path — presumably `compensation`
// carries the weight compensation used with the activation zero-point;
// confirm against the packer.
template <typename scalar_t>
void tinygemm_kernel(scalar_t* C, float* C_temp, const uint8_t* A,
                     const float* scales_a, const int32_t* qzeros_a,
                     const uint8_t* B, const float* scales_b,
                     const int8_t* qzeros_b, const int32_t* compensation,
                     int8_t* dqB_tmp, int64_t M, int64_t K, int64_t lda,
                     int64_t ldc_f, int64_t ldc_s, bool store_out,
                     bool use_brgemm) {
  _dequant_gemm_accum<false, BLOCK_N, BLOCK_N / 2>(
      C_temp, A, scales_a, qzeros_a, B, scales_b, qzeros_b, compensation,
      dqB_tmp, M, K, lda, ldc_f, use_brgemm);
  if (store_out) {
    // Convert each accumulated fp32 row into the caller's output dtype.
    for (int64_t m = 0; m < M; ++m) {
      copy_stub<scalar_t>(C + m * ldc_s, C_temp + m * ldc_f, BLOCK_N);
    }
  }
}
|
||||||
|
|
||||||
|
// Explicit instantiations of tinygemm_kernel for the two supported
// reduced-precision output dtypes. The macro mirrors the template's
// signature exactly so the symbols are emitted in this translation unit.
#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE)                                 \
  template void tinygemm_kernel<TYPE>(                                      \
      TYPE * C, float* C_temp, const uint8_t* A, const float* scales_a,     \
      const int32_t* qzeros_a, const uint8_t* B, const float* scales_b,     \
      const int8_t* qzeros_b, const int32_t* compensation, int8_t* dqB_tmp, \
      int64_t M, int64_t K, int64_t lda, int64_t ldc_f, int64_t ldc_s,      \
      bool store_out, bool use_brgemm)

INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16);
INSTANTIATE_TINYGEMM_TEMPLATE(at::Half);
|
||||||
|
|
||||||
|
// Public entry point for the DA8W4 (int4 weight, int8 activation) CPU GEMM.
// The output dtype follows the activation dtype of `x`.
at::Tensor int4_scaled_mm_cpu(at::Tensor& x, at::Tensor& w, at::Tensor& w_zeros,
                              at::Tensor& w_scales,
                              std::optional<at::Tensor> bias) {
  const at::ScalarType out_dtype = x.scalar_type();
  return int4_scaled_mm_cpu_with_quant(x, w, w_scales, w_zeros, bias,
                                       out_dtype);
}
|
||||||
@@ -79,6 +79,14 @@ at::Tensor int8_scaled_mm_with_quant(at::Tensor& mat1, at::Tensor& mat2,
|
|||||||
const std::optional<at::Tensor>& bias,
|
const std::optional<at::Tensor>& bias,
|
||||||
at::ScalarType out_dtype, bool is_vnni);
|
at::ScalarType out_dtype, bool is_vnni);
|
||||||
|
|
||||||
|
// Adapted from sglang: INT4 W4A8 kernels
|
||||||
|
std::tuple<at::Tensor, at::Tensor, at::Tensor> convert_weight_packed_scale_zp(
|
||||||
|
at::Tensor qweight, at::Tensor qzeros, at::Tensor scales);
|
||||||
|
|
||||||
|
at::Tensor int4_scaled_mm_cpu(at::Tensor& x, at::Tensor& w, at::Tensor& w_zeros,
|
||||||
|
at::Tensor& w_scales,
|
||||||
|
std::optional<at::Tensor> bias);
|
||||||
|
|
||||||
torch::Tensor get_scheduler_metadata(
|
torch::Tensor get_scheduler_metadata(
|
||||||
const int64_t num_req, const int64_t num_heads_q,
|
const int64_t num_req, const int64_t num_heads_q,
|
||||||
const int64_t num_heads_kv, const int64_t head_dim,
|
const int64_t num_heads_kv, const int64_t head_dim,
|
||||||
@@ -285,6 +293,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
"Tensor? bias, ScalarType out_dtype, bool is_vnni) -> Tensor");
|
"Tensor? bias, ScalarType out_dtype, bool is_vnni) -> Tensor");
|
||||||
ops.impl("int8_scaled_mm_with_quant", torch::kCPU,
|
ops.impl("int8_scaled_mm_with_quant", torch::kCPU,
|
||||||
&int8_scaled_mm_with_quant);
|
&int8_scaled_mm_with_quant);
|
||||||
|
|
||||||
|
// Adapted from sglang: INT4 W4A8 kernels
|
||||||
|
ops.def(
|
||||||
|
"convert_weight_packed_scale_zp(Tensor qweight, Tensor qzeros, "
|
||||||
|
"Tensor scales) -> (Tensor, Tensor, Tensor)");
|
||||||
|
ops.impl("convert_weight_packed_scale_zp", torch::kCPU,
|
||||||
|
&convert_weight_packed_scale_zp);
|
||||||
|
|
||||||
|
ops.def(
|
||||||
|
"int4_scaled_mm_cpu(Tensor(a0!) x, Tensor(a1!) w, Tensor(a2!) w_zeros, "
|
||||||
|
"Tensor(a3!) w_scales, Tensor? bias) -> Tensor");
|
||||||
|
ops.impl("int4_scaled_mm_cpu", torch::kCPU, &int4_scaled_mm_cpu);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// CPU attention kernels
|
// CPU attention kernels
|
||||||
|
|||||||
@@ -6,14 +6,16 @@
|
|||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
|
|
||||||
|
#include <torch/headeronly/util/shim_utils.h>
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Helper function for checking CUTLASS errors
|
* Helper function for checking CUTLASS errors
|
||||||
*/
|
*/
|
||||||
#define CUTLASS_CHECK(status) \
|
#define CUTLASS_CHECK(status) \
|
||||||
{ \
|
{ \
|
||||||
cutlass::Status error = status; \
|
cutlass::Status error = status; \
|
||||||
TORCH_CHECK(error == cutlass::Status::kSuccess, \
|
STD_TORCH_CHECK(error == cutlass::Status::kSuccess, \
|
||||||
cutlassGetStatusString(error)); \
|
cutlassGetStatusString(error)); \
|
||||||
}
|
}
|
||||||
|
|
||||||
inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
|
inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
|
||||||
|
|||||||
@@ -3,6 +3,14 @@
|
|||||||
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
|
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
|
||||||
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp"
|
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp"
|
||||||
|
|
||||||
|
// This header is shared by both _C (unstable ABI) and _C_stable_libtorch
|
||||||
|
// (stable ABI) targets. When compiled under the stable ABI target,
|
||||||
|
// TORCH_TARGET_VERSION is defined and Tensor is unavailable, so we
|
||||||
|
// use torch::stable::Tensor instead.
|
||||||
|
#ifdef TORCH_TARGET_VERSION
|
||||||
|
#include <torch/csrc/stable/tensor.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
/*
|
/*
|
||||||
This file defines custom epilogues for fusing channel scales, token scales,
|
This file defines custom epilogues for fusing channel scales, token scales,
|
||||||
bias, and activation zero-points onto a GEMM operation using the
|
bias, and activation zero-points onto a GEMM operation using the
|
||||||
@@ -15,6 +23,12 @@
|
|||||||
|
|
||||||
namespace vllm::c3x {
|
namespace vllm::c3x {
|
||||||
|
|
||||||
|
#ifdef TORCH_TARGET_VERSION
|
||||||
|
using TensorType = torch::stable::Tensor;
|
||||||
|
#else
|
||||||
|
using TensorType = torch::Tensor;
|
||||||
|
#endif
|
||||||
|
|
||||||
using namespace cute;
|
using namespace cute;
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
@@ -84,7 +98,7 @@ struct ScaledEpilogueBase {
|
|||||||
// from a tensor. It can handle both row and column, as well as row/column or
|
// from a tensor. It can handle both row and column, as well as row/column or
|
||||||
// scalar cases.
|
// scalar cases.
|
||||||
template <typename Descriptor, typename T>
|
template <typename Descriptor, typename T>
|
||||||
static auto args_from_tensor(torch::Tensor const& tensor) {
|
static auto args_from_tensor(TensorType const& tensor) {
|
||||||
using Arguments = typename Descriptor::Arguments;
|
using Arguments = typename Descriptor::Arguments;
|
||||||
auto* data_ptr = static_cast<T*>(tensor.data_ptr());
|
auto* data_ptr = static_cast<T*>(tensor.data_ptr());
|
||||||
if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
|
if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
|
||||||
@@ -100,7 +114,7 @@ struct ScaledEpilogueBase {
|
|||||||
// This overload handles the case where there might not be a tensor, in which
|
// This overload handles the case where there might not be a tensor, in which
|
||||||
// case a nullptr is passed and a constant (0) is used.
|
// case a nullptr is passed and a constant (0) is used.
|
||||||
template <typename Descriptor, typename T>
|
template <typename Descriptor, typename T>
|
||||||
static auto args_from_tensor(std::optional<torch::Tensor> const& tensor) {
|
static auto args_from_tensor(std::optional<TensorType> const& tensor) {
|
||||||
using Arguments = typename Descriptor::Arguments;
|
using Arguments = typename Descriptor::Arguments;
|
||||||
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
|
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
|
||||||
static_assert(std::is_same_v<Descriptor, ColLoad<T, true>> ||
|
static_assert(std::is_same_v<Descriptor, ColLoad<T, true>> ||
|
||||||
@@ -158,8 +172,8 @@ struct ScaledEpilogue
|
|||||||
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
|
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
|
||||||
using ArgumentType = typename EVTCompute::Arguments;
|
using ArgumentType = typename EVTCompute::Arguments;
|
||||||
|
|
||||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
static ArgumentType prepare_args(TensorType const& a_scales,
|
||||||
torch::Tensor const& b_scales) {
|
TensorType const& b_scales) {
|
||||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||||
|
|
||||||
@@ -203,9 +217,9 @@ struct ScaledEpilogueBias
|
|||||||
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
|
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
|
||||||
|
|
||||||
using ArgumentType = typename EVTCompute::Arguments;
|
using ArgumentType = typename EVTCompute::Arguments;
|
||||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
static ArgumentType prepare_args(TensorType const& a_scales,
|
||||||
torch::Tensor const& b_scales,
|
TensorType const& b_scales,
|
||||||
torch::Tensor const& bias) {
|
TensorType const& bias) {
|
||||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||||
@@ -246,9 +260,9 @@ struct ScaledEpilogueColumnBias
|
|||||||
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
|
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
|
||||||
|
|
||||||
using ArgumentType = typename EVTCompute::Arguments;
|
using ArgumentType = typename EVTCompute::Arguments;
|
||||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
static ArgumentType prepare_args(TensorType const& a_scales,
|
||||||
torch::Tensor const& b_scales,
|
TensorType const& b_scales,
|
||||||
torch::Tensor const& bias) {
|
TensorType const& bias) {
|
||||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||||
@@ -304,10 +318,10 @@ struct ScaledEpilogueBiasAzp
|
|||||||
EVTComputeScaleB, Bias>;
|
EVTComputeScaleB, Bias>;
|
||||||
using ArgumentType = typename EVTCompute::Arguments;
|
using ArgumentType = typename EVTCompute::Arguments;
|
||||||
|
|
||||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
static ArgumentType prepare_args(TensorType const& a_scales,
|
||||||
torch::Tensor const& b_scales,
|
TensorType const& b_scales,
|
||||||
torch::Tensor const& azp_adj,
|
TensorType const& azp_adj,
|
||||||
std::optional<torch::Tensor> const& bias) {
|
std::optional<TensorType> const& bias) {
|
||||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||||
@@ -380,11 +394,11 @@ struct ScaledEpilogueBiasAzpToken
|
|||||||
EVTComputeScaleB, Bias>;
|
EVTComputeScaleB, Bias>;
|
||||||
using ArgumentType = typename EVTCompute::Arguments;
|
using ArgumentType = typename EVTCompute::Arguments;
|
||||||
|
|
||||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
static ArgumentType prepare_args(TensorType const& a_scales,
|
||||||
torch::Tensor const& b_scales,
|
TensorType const& b_scales,
|
||||||
torch::Tensor const& azp_adj,
|
TensorType const& azp_adj,
|
||||||
torch::Tensor const& azp,
|
TensorType const& azp,
|
||||||
std::optional<torch::Tensor> const& bias) {
|
std::optional<TensorType> const& bias) {
|
||||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <torch/csrc/stable/tensor.h>
|
||||||
|
|
||||||
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp"
|
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp"
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@@ -52,7 +54,7 @@ struct ScaledEpilogueBase {
|
|||||||
// from a tensor. It can handle both row and column, as well as row/column or
|
// from a tensor. It can handle both row and column, as well as row/column or
|
||||||
// scalar cases.
|
// scalar cases.
|
||||||
template <typename Descriptor, typename T>
|
template <typename Descriptor, typename T>
|
||||||
static auto args_from_tensor(torch::Tensor const& tensor) {
|
static auto args_from_tensor(torch::stable::Tensor const& tensor) {
|
||||||
using Arguments = typename Descriptor::Arguments;
|
using Arguments = typename Descriptor::Arguments;
|
||||||
auto* data_ptr = static_cast<T*>(tensor.data_ptr());
|
auto* data_ptr = static_cast<T*>(tensor.data_ptr());
|
||||||
if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
|
if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
|
||||||
@@ -68,7 +70,8 @@ struct ScaledEpilogueBase {
|
|||||||
// This overload handles the case where there might not be a tensor, in which
|
// This overload handles the case where there might not be a tensor, in which
|
||||||
// case a nullptr is passed and a constant (0) is used.
|
// case a nullptr is passed and a constant (0) is used.
|
||||||
template <typename Descriptor, typename T>
|
template <typename Descriptor, typename T>
|
||||||
static auto args_from_tensor(std::optional<torch::Tensor> const& tensor) {
|
static auto args_from_tensor(
|
||||||
|
std::optional<torch::stable::Tensor> const& tensor) {
|
||||||
static_assert(std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
|
static_assert(std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
|
||||||
using Arguments = typename Descriptor::Arguments;
|
using Arguments = typename Descriptor::Arguments;
|
||||||
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
|
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
|
||||||
@@ -117,8 +120,8 @@ struct ScaledEpilogue
|
|||||||
cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
|
cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
|
||||||
using ArgumentType = typename EVTCompute::Arguments;
|
using ArgumentType = typename EVTCompute::Arguments;
|
||||||
|
|
||||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
static ArgumentType prepare_args(torch::stable::Tensor const& a_scales,
|
||||||
torch::Tensor const& b_scales) {
|
torch::stable::Tensor const& b_scales) {
|
||||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||||
|
|
||||||
@@ -160,9 +163,9 @@ struct ScaledEpilogueBias
|
|||||||
using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
|
using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
|
||||||
EVTCompute0, Bias>;
|
EVTCompute0, Bias>;
|
||||||
using ArgumentType = typename EVTCompute::Arguments;
|
using ArgumentType = typename EVTCompute::Arguments;
|
||||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
static ArgumentType prepare_args(torch::stable::Tensor const& a_scales,
|
||||||
torch::Tensor const& b_scales,
|
torch::stable::Tensor const& b_scales,
|
||||||
torch::Tensor const& bias) {
|
torch::stable::Tensor const& bias) {
|
||||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||||
@@ -220,10 +223,11 @@ struct ScaledEpilogueBiasAzp
|
|||||||
|
|
||||||
using ArgumentType = typename EVTCompute::Arguments;
|
using ArgumentType = typename EVTCompute::Arguments;
|
||||||
|
|
||||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
static ArgumentType prepare_args(
|
||||||
torch::Tensor const& b_scales,
|
torch::stable::Tensor const& a_scales,
|
||||||
torch::Tensor const& azp_adj,
|
torch::stable::Tensor const& b_scales,
|
||||||
std::optional<torch::Tensor> const& bias) {
|
torch::stable::Tensor const& azp_adj,
|
||||||
|
std::optional<torch::stable::Tensor> const& bias) {
|
||||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||||
@@ -298,11 +302,11 @@ struct ScaledEpilogueBiasAzpToken
|
|||||||
|
|
||||||
using ArgumentType = typename EVTCompute::Arguments;
|
using ArgumentType = typename EVTCompute::Arguments;
|
||||||
|
|
||||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
static ArgumentType prepare_args(
|
||||||
torch::Tensor const& b_scales,
|
torch::stable::Tensor const& a_scales,
|
||||||
torch::Tensor const& azp_adj,
|
torch::stable::Tensor const& b_scales,
|
||||||
torch::Tensor const& azp,
|
torch::stable::Tensor const& azp_adj, torch::stable::Tensor const& azp,
|
||||||
std::optional<torch::Tensor> const& bias) {
|
std::optional<torch::stable::Tensor> const& bias) {
|
||||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||||
@@ -27,4 +27,61 @@ void per_token_group_quant_int8(const torch::stable::Tensor& input,
|
|||||||
torch::stable::Tensor& output_s,
|
torch::stable::Tensor& output_s,
|
||||||
int64_t group_size, double eps, double int8_min,
|
int64_t group_size, double eps, double int8_min,
|
||||||
double int8_max);
|
double int8_max);
|
||||||
|
|
||||||
|
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
|
||||||
|
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability);
|
||||||
|
bool cutlass_group_gemm_supported(int64_t cuda_device_capability);
|
||||||
|
|
||||||
|
void cutlass_scaled_mm(torch::stable::Tensor& out,
|
||||||
|
torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b,
|
||||||
|
torch::stable::Tensor const& a_scales,
|
||||||
|
torch::stable::Tensor const& b_scales,
|
||||||
|
std::optional<torch::stable::Tensor> const& bias);
|
||||||
|
|
||||||
|
void cutlass_moe_mm(torch::stable::Tensor& out_tensors,
|
||||||
|
torch::stable::Tensor const& a_tensors,
|
||||||
|
torch::stable::Tensor const& b_tensors,
|
||||||
|
torch::stable::Tensor const& a_scales,
|
||||||
|
torch::stable::Tensor const& b_scales,
|
||||||
|
torch::stable::Tensor const& expert_offsets,
|
||||||
|
torch::stable::Tensor const& problem_sizes,
|
||||||
|
torch::stable::Tensor const& a_strides,
|
||||||
|
torch::stable::Tensor const& b_strides,
|
||||||
|
torch::stable::Tensor const& c_strides, bool per_act_token,
|
||||||
|
bool per_out_ch);
|
||||||
|
|
||||||
|
void cutlass_scaled_mm_azp(torch::stable::Tensor& out,
|
||||||
|
torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b,
|
||||||
|
torch::stable::Tensor const& a_scales,
|
||||||
|
torch::stable::Tensor const& b_scales,
|
||||||
|
torch::stable::Tensor const& azp_adj,
|
||||||
|
std::optional<torch::stable::Tensor> const& azp,
|
||||||
|
std::optional<torch::stable::Tensor> const& bias);
|
||||||
|
|
||||||
|
void get_cutlass_moe_mm_data(
|
||||||
|
const torch::stable::Tensor& topk_ids,
|
||||||
|
torch::stable::Tensor& expert_offsets,
|
||||||
|
torch::stable::Tensor& problem_sizes1,
|
||||||
|
torch::stable::Tensor& problem_sizes2,
|
||||||
|
torch::stable::Tensor& input_permutation,
|
||||||
|
torch::stable::Tensor& output_permutation, const int64_t num_experts,
|
||||||
|
const int64_t n, const int64_t k,
|
||||||
|
const std::optional<torch::stable::Tensor>& blockscale_offsets,
|
||||||
|
const bool is_gated);
|
||||||
|
|
||||||
|
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
|
||||||
|
const torch::stable::Tensor& expert_first_token_offset,
|
||||||
|
torch::stable::Tensor& problem_sizes1,
|
||||||
|
torch::stable::Tensor& problem_sizes2, const int64_t n, const int64_t k,
|
||||||
|
const bool swap_ab);
|
||||||
|
|
||||||
|
void get_cutlass_batched_moe_mm_data(
|
||||||
|
torch::stable::Tensor& expert_offsets,
|
||||||
|
torch::stable::Tensor& problem_sizes1,
|
||||||
|
torch::stable::Tensor& problem_sizes2,
|
||||||
|
const torch::stable::Tensor& expert_num_tokens,
|
||||||
|
const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
|
||||||
|
const int64_t k);
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
@@ -2,9 +2,10 @@
|
|||||||
|
|
||||||
// clang-format will break include orders
|
// clang-format will break include orders
|
||||||
// clang-format off
|
// clang-format off
|
||||||
#include <torch/all.h>
|
#include <torch/csrc/stable/tensor.h>
|
||||||
|
#include <torch/csrc/stable/ops.h>
|
||||||
|
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include "libtorch_stable/torch_utils.h"
|
||||||
|
|
||||||
#include "cutlass/cutlass.h"
|
#include "cutlass/cutlass.h"
|
||||||
|
|
||||||
@@ -25,14 +26,14 @@
|
|||||||
namespace vllm::c3x {
|
namespace vllm::c3x {
|
||||||
|
|
||||||
static inline cute::Shape<int, int, int, int> get_problem_shape(
|
static inline cute::Shape<int, int, int, int> get_problem_shape(
|
||||||
torch::Tensor const& a, torch::Tensor const& b) {
|
torch::stable::Tensor const& a, torch::stable::Tensor const& b) {
|
||||||
int32_t m = a.size(0), n = b.size(1), k = a.size(1);
|
int32_t m = a.size(0), n = b.size(1), k = a.size(1);
|
||||||
return {m, n, k, 1};
|
return {m, n, k, 1};
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename GemmKernel>
|
template <typename GemmKernel>
|
||||||
void cutlass_gemm_caller(
|
void cutlass_gemm_caller(
|
||||||
torch::Device device, cute::Shape<int, int, int, int> prob_shape,
|
torch::stable::Device device, cute::Shape<int, int, int, int> prob_shape,
|
||||||
typename GemmKernel::MainloopArguments mainloop_args,
|
typename GemmKernel::MainloopArguments mainloop_args,
|
||||||
typename GemmKernel::EpilogueArguments epilogue_args,
|
typename GemmKernel::EpilogueArguments epilogue_args,
|
||||||
typename GemmKernel::TileSchedulerArguments scheduler = {}) {
|
typename GemmKernel::TileSchedulerArguments scheduler = {}) {
|
||||||
@@ -50,19 +51,20 @@ void cutlass_gemm_caller(
|
|||||||
CUTLASS_CHECK(gemm_op.can_implement(args));
|
CUTLASS_CHECK(gemm_op.can_implement(args));
|
||||||
|
|
||||||
size_t workspace_size = gemm_op.get_workspace_size(args);
|
size_t workspace_size = gemm_op.get_workspace_size(args);
|
||||||
auto const workspace_options =
|
auto workspace =
|
||||||
torch::TensorOptions().dtype(torch::kUInt8).device(device);
|
torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
|
||||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
std::nullopt, device);
|
||||||
|
|
||||||
auto stream = at::cuda::getCurrentCUDAStream(device.index());
|
auto stream = get_current_cuda_stream(device.index());
|
||||||
|
|
||||||
cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
|
cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
|
||||||
CUTLASS_CHECK(status);
|
CUTLASS_CHECK(status);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Gemm, typename... EpilogueArgs>
|
template <typename Gemm, typename... EpilogueArgs>
|
||||||
void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
|
void cutlass_gemm_caller(torch::stable::Tensor& out,
|
||||||
torch::Tensor const& b,
|
torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b,
|
||||||
EpilogueArgs&&... epilogue_params) {
|
EpilogueArgs&&... epilogue_params) {
|
||||||
using ElementAB = typename Gemm::ElementAB;
|
using ElementAB = typename Gemm::ElementAB;
|
||||||
using ElementC = typename Gemm::ElementC;
|
using ElementC = typename Gemm::ElementC;
|
||||||
@@ -4,13 +4,12 @@
|
|||||||
|
|
||||||
namespace vllm {
|
namespace vllm {
|
||||||
|
|
||||||
void cutlass_scaled_mm_azp_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
|
void cutlass_scaled_mm_azp_sm90_int8(
|
||||||
torch::Tensor const& b,
|
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||||
torch::Tensor const& a_scales,
|
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||||
torch::Tensor const& b_scales,
|
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
|
||||||
torch::Tensor const& azp_adj,
|
std::optional<torch::stable::Tensor> const& azp,
|
||||||
std::optional<torch::Tensor> const& azp,
|
std::optional<torch::stable::Tensor> const& bias) {
|
||||||
std::optional<torch::Tensor> const& bias) {
|
|
||||||
if (azp) {
|
if (azp) {
|
||||||
return cutlass_scaled_mm_sm90_int8_epilogue<
|
return cutlass_scaled_mm_sm90_int8_epilogue<
|
||||||
c3x::ScaledEpilogueBiasAzpToken>(out, a, b, a_scales, b_scales, azp_adj,
|
c3x::ScaledEpilogueBiasAzpToken>(out, a, b, a_scales, b_scales, azp_adj,
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
#include "scaled_mm_kernels.hpp"
|
||||||
|
#include "scaled_mm_blockwise_sm100_fp8_dispatch.cuh"
|
||||||
|
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||||
|
|
||||||
|
namespace vllm {
|
||||||
|
|
||||||
|
void cutlass_scaled_mm_blockwise_sm100_fp8(
|
||||||
|
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||||
|
torch::stable::Tensor const& b_scales) {
|
||||||
|
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
|
||||||
|
cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::bfloat16_t>(
|
||||||
|
out, a, b, a_scales, b_scales);
|
||||||
|
|
||||||
|
} else {
|
||||||
|
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
|
||||||
|
cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::half_t>(
|
||||||
|
out, a, b, a_scales, b_scales);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace vllm
|
||||||
@@ -1,5 +1,7 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <torch/headeronly/util/shim_utils.h>
|
||||||
|
|
||||||
#include "cuda_utils.h"
|
#include "cuda_utils.h"
|
||||||
#include "cutlass/cutlass.h"
|
#include "cutlass/cutlass.h"
|
||||||
#include "cutlass/numeric_types.h"
|
#include "cutlass/numeric_types.h"
|
||||||
@@ -130,10 +132,10 @@ struct cutlass_3x_gemm_fp8_blockwise {
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <typename Gemm>
|
template <typename Gemm>
|
||||||
void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
|
void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||||
torch::Tensor const& b,
|
torch::stable::Tensor const& b,
|
||||||
torch::Tensor const& a_scales,
|
torch::stable::Tensor const& a_scales,
|
||||||
torch::Tensor const& b_scales) {
|
torch::stable::Tensor const& b_scales) {
|
||||||
static constexpr bool swap_ab = Gemm::swap_ab;
|
static constexpr bool swap_ab = Gemm::swap_ab;
|
||||||
using GemmKernel = typename Gemm::GemmKernel;
|
using GemmKernel = typename Gemm::GemmKernel;
|
||||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||||
@@ -200,11 +202,11 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename OutType>
|
template <typename OutType>
|
||||||
void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out,
|
void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::stable::Tensor& out,
|
||||||
torch::Tensor const& a,
|
torch::stable::Tensor const& a,
|
||||||
torch::Tensor const& b,
|
torch::stable::Tensor const& b,
|
||||||
torch::Tensor const& a_scales,
|
torch::stable::Tensor const& a_scales,
|
||||||
torch::Tensor const& b_scales) {
|
torch::stable::Tensor const& b_scales) {
|
||||||
int32_t m = a.size(0), n = b.size(1), k = a.size(1), sms;
|
int32_t m = a.size(0), n = b.size(1), k = a.size(1), sms;
|
||||||
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device());
|
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device());
|
||||||
|
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
#include "scaled_mm_kernels.hpp"
|
||||||
|
#include "scaled_mm_blockwise_sm120_fp8_dispatch.cuh"
|
||||||
|
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||||
|
|
||||||
|
namespace vllm {
|
||||||
|
|
||||||
|
void cutlass_scaled_mm_blockwise_sm120_fp8(
|
||||||
|
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||||
|
torch::stable::Tensor const& b_scales) {
|
||||||
|
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
|
||||||
|
cutlass_gemm_blockwise_sm120_fp8_dispatch<cutlass::bfloat16_t>(
|
||||||
|
out, a, b, a_scales, b_scales);
|
||||||
|
|
||||||
|
} else {
|
||||||
|
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
|
||||||
|
cutlass_gemm_blockwise_sm120_fp8_dispatch<cutlass::half_t>(
|
||||||
|
out, a, b, a_scales, b_scales);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace vllm
|
||||||
@@ -1,5 +1,7 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <torch/headeronly/util/shim_utils.h>
|
||||||
|
|
||||||
#include "cuda_utils.h"
|
#include "cuda_utils.h"
|
||||||
#include "cutlass/cutlass.h"
|
#include "cutlass/cutlass.h"
|
||||||
#include "cutlass/numeric_types.h"
|
#include "cutlass/numeric_types.h"
|
||||||
@@ -138,10 +140,10 @@ struct sm120_blockwise_fp8_config_M64 {
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <typename Gemm>
|
template <typename Gemm>
|
||||||
void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
|
void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||||
torch::Tensor const& b,
|
torch::stable::Tensor const& b,
|
||||||
torch::Tensor const& a_scales,
|
torch::stable::Tensor const& a_scales,
|
||||||
torch::Tensor const& b_scales) {
|
torch::stable::Tensor const& b_scales) {
|
||||||
using GemmKernel = typename Gemm::GemmKernel;
|
using GemmKernel = typename Gemm::GemmKernel;
|
||||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||||
@@ -196,11 +198,11 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename OutType>
|
template <typename OutType>
|
||||||
void cutlass_gemm_blockwise_sm120_fp8_dispatch(torch::Tensor& out,
|
void cutlass_gemm_blockwise_sm120_fp8_dispatch(torch::stable::Tensor& out,
|
||||||
torch::Tensor const& a,
|
torch::stable::Tensor const& a,
|
||||||
torch::Tensor const& b,
|
torch::stable::Tensor const& b,
|
||||||
torch::Tensor const& a_scales,
|
torch::stable::Tensor const& a_scales,
|
||||||
torch::Tensor const& b_scales) {
|
torch::stable::Tensor const& b_scales) {
|
||||||
int M = a.size(0);
|
int M = a.size(0);
|
||||||
if (M <= 256) {
|
if (M <= 256) {
|
||||||
using Gemm = typename sm120_blockwise_fp8_config_M64<OutType>::Gemm;
|
using Gemm = typename sm120_blockwise_fp8_config_M64<OutType>::Gemm;
|
||||||
@@ -0,0 +1,23 @@
|
|||||||
|
|
||||||
|
#include "scaled_mm_kernels.hpp"
|
||||||
|
#include "scaled_mm_blockwise_sm90_fp8_dispatch.cuh"
|
||||||
|
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||||
|
|
||||||
|
namespace vllm {
|
||||||
|
|
||||||
|
void cutlass_scaled_mm_blockwise_sm90_fp8(
|
||||||
|
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||||
|
torch::stable::Tensor const& b_scales) {
|
||||||
|
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
|
||||||
|
cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::bfloat16_t>(
|
||||||
|
out, a, b, a_scales, b_scales);
|
||||||
|
|
||||||
|
} else {
|
||||||
|
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
|
||||||
|
cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::half_t>(
|
||||||
|
out, a, b, a_scales, b_scales);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace vllm
|
||||||
@@ -1,5 +1,7 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <torch/headeronly/util/shim_utils.h>
|
||||||
|
|
||||||
#include "cutlass/cutlass.h"
|
#include "cutlass/cutlass.h"
|
||||||
#include "cutlass/numeric_types.h"
|
#include "cutlass/numeric_types.h"
|
||||||
|
|
||||||
@@ -101,10 +103,10 @@ struct cutlass_3x_gemm_fp8_blockwise {
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <typename Gemm>
|
template <typename Gemm>
|
||||||
void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
|
void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||||
torch::Tensor const& b,
|
torch::stable::Tensor const& b,
|
||||||
torch::Tensor const& a_scales,
|
torch::stable::Tensor const& a_scales,
|
||||||
torch::Tensor const& b_scales) {
|
torch::stable::Tensor const& b_scales) {
|
||||||
using GemmKernel = typename Gemm::GemmKernel;
|
using GemmKernel = typename Gemm::GemmKernel;
|
||||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||||
@@ -120,7 +122,7 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
|
|||||||
|
|
||||||
int32_t m = a.size(0), n = b.size(1), k = a.size(1);
|
int32_t m = a.size(0), n = b.size(1), k = a.size(1);
|
||||||
|
|
||||||
TORCH_CHECK(m % 4 == 0, "m must be divisible by 4");
|
STD_TORCH_CHECK(m % 4 == 0, "m must be divisible by 4");
|
||||||
|
|
||||||
StrideA a_stride;
|
StrideA a_stride;
|
||||||
StrideB b_stride;
|
StrideB b_stride;
|
||||||
@@ -161,11 +163,11 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename OutType>
|
template <typename OutType>
|
||||||
void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::Tensor& out,
|
void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::stable::Tensor& out,
|
||||||
torch::Tensor const& a,
|
torch::stable::Tensor const& a,
|
||||||
torch::Tensor const& b,
|
torch::stable::Tensor const& b,
|
||||||
torch::Tensor const& a_scales,
|
torch::stable::Tensor const& a_scales,
|
||||||
torch::Tensor const& b_scales) {
|
torch::stable::Tensor const& b_scales) {
|
||||||
// TODO: better heuristics
|
// TODO: better heuristics
|
||||||
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
|
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
|
||||||
OutType, 1, 128, 128, Shape<_128, _128, _128>,
|
OutType, 1, 128, 128, Shape<_128, _128, _128>,
|
||||||
@@ -1,52 +1,57 @@
|
|||||||
#include <torch/all.h>
|
#include <torch/csrc/stable/tensor.h>
|
||||||
|
#include <torch/headeronly/core/ScalarType.h>
|
||||||
#include "cuda_utils.h"
|
#include "cuda_utils.h"
|
||||||
#include "cutlass_extensions/common.hpp"
|
#include "cutlass_extensions/common.hpp"
|
||||||
|
|
||||||
template <typename Fp8Func, typename Int8Func, typename BlockwiseFunc>
|
template <typename Fp8Func, typename Int8Func, typename BlockwiseFunc>
|
||||||
void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
|
void dispatch_scaled_mm(torch::stable::Tensor& c,
|
||||||
torch::Tensor const& b, torch::Tensor const& a_scales,
|
torch::stable::Tensor const& a,
|
||||||
torch::Tensor const& b_scales,
|
torch::stable::Tensor const& b,
|
||||||
std::optional<torch::Tensor> const& bias,
|
torch::stable::Tensor const& a_scales,
|
||||||
|
torch::stable::Tensor const& b_scales,
|
||||||
|
std::optional<torch::stable::Tensor> const& bias,
|
||||||
Fp8Func fp8_func, Int8Func int8_func,
|
Fp8Func fp8_func, Int8Func int8_func,
|
||||||
BlockwiseFunc blockwise_func) {
|
BlockwiseFunc blockwise_func) {
|
||||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
STD_TORCH_CHECK(a_scales.scalar_type() ==
|
||||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
torch::headeronly::ScalarType::Float);
|
||||||
|
STD_TORCH_CHECK(b_scales.scalar_type() ==
|
||||||
|
torch::headeronly::ScalarType::Float);
|
||||||
|
|
||||||
int M = a.size(0), N = b.size(1), K = a.size(1);
|
int M = a.size(0), N = b.size(1), K = a.size(1);
|
||||||
|
|
||||||
if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
|
if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
|
||||||
(b_scales.numel() == 1 || b_scales.numel() == b.size(1))) {
|
(b_scales.numel() == 1 || b_scales.numel() == b.size(1))) {
|
||||||
// Standard per-tensor/per-token/per-channel scaling
|
// Standard per-tensor/per-token/per-channel scaling
|
||||||
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||||
if (a.dtype() == torch::kFloat8_e4m3fn) {
|
if (a.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn) {
|
||||||
fp8_func(c, a, b, a_scales, b_scales, bias);
|
fp8_func(c, a, b, a_scales, b_scales, bias);
|
||||||
} else {
|
} else {
|
||||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||||
if constexpr (!std::is_same_v<Int8Func, std::nullptr_t>) {
|
if constexpr (!std::is_same_v<Int8Func, std::nullptr_t>) {
|
||||||
int8_func(c, a, b, a_scales, b_scales, bias);
|
int8_func(c, a, b, a_scales, b_scales, bias);
|
||||||
} else {
|
} else {
|
||||||
int32_t version_num = get_sm_version_num();
|
int32_t version_num = get_sm_version_num();
|
||||||
TORCH_CHECK(
|
STD_TORCH_CHECK(
|
||||||
false, "Int8 not supported on SM", version_num,
|
false, "Int8 not supported on SM", version_num,
|
||||||
". Use FP8 quantization instead, or run on older arch (SM < 100).");
|
". Use FP8 quantization instead, or run on older arch (SM < 100).");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
TORCH_CHECK(a_scales.dim() == 2, "a scale must be 2d tensor.");
|
STD_TORCH_CHECK(a_scales.dim() == 2, "a scale must be 2d tensor.");
|
||||||
TORCH_CHECK(b_scales.dim() == 2, "b scale must be 2d tensor.");
|
STD_TORCH_CHECK(b_scales.dim() == 2, "b scale must be 2d tensor.");
|
||||||
int32_t version_num = get_sm_version_num();
|
int32_t version_num = get_sm_version_num();
|
||||||
if (version_num >= 90) {
|
if (version_num >= 90) {
|
||||||
TORCH_CHECK(
|
STD_TORCH_CHECK(
|
||||||
a.size(0) == a_scales.size(0) &&
|
a.size(0) == a_scales.size(0) &&
|
||||||
cuda_utils::ceil_div(a.size(1), int64_t(128)) == a_scales.size(1),
|
cuda_utils::ceil_div(a.size(1), int64_t(128)) == a_scales.size(1),
|
||||||
"a_scale_group_shape must be [1, 128].");
|
"a_scale_group_shape must be [1, 128].");
|
||||||
TORCH_CHECK(
|
STD_TORCH_CHECK(
|
||||||
cuda_utils::ceil_div(b.size(0), int64_t(128)) == b_scales.size(0) &&
|
cuda_utils::ceil_div(b.size(0), int64_t(128)) == b_scales.size(0) &&
|
||||||
cuda_utils::ceil_div(b.size(1), int64_t(128)) == b_scales.size(1),
|
cuda_utils::ceil_div(b.size(1), int64_t(128)) == b_scales.size(1),
|
||||||
"b_scale_group_shape must be [128, 128].");
|
"b_scale_group_shape must be [128, 128].");
|
||||||
}
|
}
|
||||||
|
|
||||||
TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
|
STD_TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
|
||||||
blockwise_func(c, a, b, a_scales, b_scales);
|
blockwise_func(c, a, b, a_scales, b_scales);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -0,0 +1,52 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <torch/csrc/stable/tensor.h>
|
||||||
|
|
||||||
|
namespace vllm {
|
||||||
|
|
||||||
|
void cutlass_scaled_mm_sm90_fp8(
|
||||||
|
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||||
|
torch::stable::Tensor const& b_scales,
|
||||||
|
std::optional<torch::stable::Tensor> const& bias);
|
||||||
|
|
||||||
|
void cutlass_scaled_mm_sm90_int8(
|
||||||
|
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||||
|
torch::stable::Tensor const& b_scales,
|
||||||
|
std::optional<torch::stable::Tensor> const& bias);
|
||||||
|
|
||||||
|
void cutlass_scaled_mm_azp_sm90_int8(
|
||||||
|
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||||
|
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
|
||||||
|
std::optional<torch::stable::Tensor> const& azp,
|
||||||
|
std::optional<torch::stable::Tensor> const& bias);
|
||||||
|
|
||||||
|
void cutlass_scaled_mm_blockwise_sm90_fp8(
|
||||||
|
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||||
|
torch::stable::Tensor const& b_scales);
|
||||||
|
|
||||||
|
void cutlass_scaled_mm_sm100_fp8(
|
||||||
|
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||||
|
torch::stable::Tensor const& b_scales,
|
||||||
|
std::optional<torch::stable::Tensor> const& bias);
|
||||||
|
|
||||||
|
void cutlass_scaled_mm_sm120_fp8(
|
||||||
|
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||||
|
torch::stable::Tensor const& b_scales,
|
||||||
|
std::optional<torch::stable::Tensor> const& bias);
|
||||||
|
|
||||||
|
void cutlass_scaled_mm_blockwise_sm100_fp8(
|
||||||
|
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||||
|
torch::stable::Tensor const& b_scales);
|
||||||
|
|
||||||
|
void cutlass_scaled_mm_blockwise_sm120_fp8(
|
||||||
|
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||||
|
torch::stable::Tensor const& b_scales);
|
||||||
|
} // namespace vllm
|
||||||
@@ -0,0 +1,24 @@
|
|||||||
|
#include "scaled_mm_kernels.hpp"
|
||||||
|
#include "scaled_mm_sm100_fp8_dispatch.cuh"
|
||||||
|
|
||||||
|
namespace vllm {
|
||||||
|
|
||||||
|
void cutlass_scaled_mm_sm100_fp8(
|
||||||
|
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||||
|
torch::stable::Tensor const& b_scales,
|
||||||
|
std::optional<torch::stable::Tensor> const& bias) {
|
||||||
|
STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||||
|
if (bias) {
|
||||||
|
STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
|
||||||
|
"currently bias dtype must match output dtype ",
|
||||||
|
out.scalar_type());
|
||||||
|
return cutlass_scaled_mm_sm100_fp8_epilogue<true>(out, a, b, a_scales,
|
||||||
|
b_scales, *bias);
|
||||||
|
} else {
|
||||||
|
return cutlass_scaled_mm_sm100_fp8_epilogue<false>(out, a, b, a_scales,
|
||||||
|
b_scales);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace vllm
|
||||||
@@ -1,5 +1,7 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <torch/headeronly/util/shim_utils.h>
|
||||||
|
|
||||||
#include "scaled_mm.cuh"
|
#include "scaled_mm.cuh"
|
||||||
#include "cutlass_gemm_caller.cuh"
|
#include "cutlass_gemm_caller.cuh"
|
||||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||||
@@ -192,8 +194,9 @@ struct sm100_fp8_config_M16_swap_ab {
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <typename Gemm, typename... EpilogueArgs>
|
template <typename Gemm, typename... EpilogueArgs>
|
||||||
void cutlass_gemm_caller_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
|
void cutlass_gemm_caller_sm100_fp8(torch::stable::Tensor& out,
|
||||||
torch::Tensor const& b,
|
torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b,
|
||||||
EpilogueArgs&&... epilogue_params) {
|
EpilogueArgs&&... epilogue_params) {
|
||||||
static constexpr bool swap_ab = Gemm::swap_ab;
|
static constexpr bool swap_ab = Gemm::swap_ab;
|
||||||
using ElementAB = typename Gemm::ElementAB;
|
using ElementAB = typename Gemm::ElementAB;
|
||||||
@@ -237,15 +240,15 @@ void cutlass_gemm_caller_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
|
|||||||
|
|
||||||
template <typename InType, typename OutType, bool EnableBias,
|
template <typename InType, typename OutType, bool EnableBias,
|
||||||
typename... EpilogueArgs>
|
typename... EpilogueArgs>
|
||||||
inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
|
inline void cutlass_gemm_sm100_fp8_dispatch(
|
||||||
torch::Tensor const& a,
|
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||||
torch::Tensor const& b,
|
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||||
torch::Tensor const& a_scales,
|
torch::stable::Tensor const& b_scales, EpilogueArgs&&... args) {
|
||||||
torch::Tensor const& b_scales,
|
|
||||||
EpilogueArgs&&... args) {
|
|
||||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
STD_TORCH_CHECK(a.scalar_type() ==
|
||||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||||
|
STD_TORCH_CHECK(b.scalar_type() ==
|
||||||
|
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||||
|
|
||||||
using Cutlass3xGemmDefault =
|
using Cutlass3xGemmDefault =
|
||||||
typename sm100_fp8_config_default<InType, OutType,
|
typename sm100_fp8_config_default<InType, OutType,
|
||||||
@@ -292,22 +295,24 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <bool EnableBias, typename... EpilogueArgs>
|
template <bool EnableBias, typename... EpilogueArgs>
|
||||||
void cutlass_scaled_mm_sm100_fp8_epilogue(torch::Tensor& out,
|
void cutlass_scaled_mm_sm100_fp8_epilogue(torch::stable::Tensor& out,
|
||||||
torch::Tensor const& a,
|
torch::stable::Tensor const& a,
|
||||||
torch::Tensor const& b,
|
torch::stable::Tensor const& b,
|
||||||
torch::Tensor const& a_scales,
|
torch::stable::Tensor const& a_scales,
|
||||||
torch::Tensor const& b_scales,
|
torch::stable::Tensor const& b_scales,
|
||||||
EpilogueArgs&&... epilogue_args) {
|
EpilogueArgs&&... epilogue_args) {
|
||||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
STD_TORCH_CHECK(a.scalar_type() ==
|
||||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||||
|
STD_TORCH_CHECK(b.scalar_type() ==
|
||||||
|
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||||
|
|
||||||
if (out.dtype() == torch::kBFloat16) {
|
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
|
||||||
return cutlass_gemm_sm100_fp8_dispatch<cutlass::float_e4m3_t,
|
return cutlass_gemm_sm100_fp8_dispatch<cutlass::float_e4m3_t,
|
||||||
cutlass::bfloat16_t, EnableBias>(
|
cutlass::bfloat16_t, EnableBias>(
|
||||||
out, a, b, a_scales, b_scales,
|
out, a, b, a_scales, b_scales,
|
||||||
std::forward<EpilogueArgs>(epilogue_args)...);
|
std::forward<EpilogueArgs>(epilogue_args)...);
|
||||||
} else {
|
} else {
|
||||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
|
||||||
return cutlass_gemm_sm100_fp8_dispatch<cutlass::float_e4m3_t,
|
return cutlass_gemm_sm100_fp8_dispatch<cutlass::float_e4m3_t,
|
||||||
cutlass::half_t, EnableBias>(
|
cutlass::half_t, EnableBias>(
|
||||||
out, a, b, a_scales, b_scales,
|
out, a, b, a_scales, b_scales,
|
||||||
@@ -0,0 +1,25 @@
|
|||||||
|
#include "scaled_mm_kernels.hpp"
|
||||||
|
#include "scaled_mm_sm120_fp8_dispatch.cuh"
|
||||||
|
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||||
|
|
||||||
|
namespace vllm {
|
||||||
|
|
||||||
|
void cutlass_scaled_mm_sm120_fp8(
|
||||||
|
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||||
|
torch::stable::Tensor const& b_scales,
|
||||||
|
std::optional<torch::stable::Tensor> const& bias) {
|
||||||
|
STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||||
|
if (bias) {
|
||||||
|
STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
|
||||||
|
"currently bias dtype must match output dtype ",
|
||||||
|
out.scalar_type());
|
||||||
|
return cutlass_scaled_mm_sm120_fp8_epilogue<c3x::ScaledEpilogueBias>(
|
||||||
|
out, a, b, a_scales, b_scales, *bias);
|
||||||
|
} else {
|
||||||
|
return cutlass_scaled_mm_sm120_fp8_epilogue<c3x::ScaledEpilogue>(
|
||||||
|
out, a, b, a_scales, b_scales);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace vllm
|
||||||
@@ -1,5 +1,7 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <torch/headeronly/util/shim_utils.h>
|
||||||
|
|
||||||
#include "scaled_mm.cuh"
|
#include "scaled_mm.cuh"
|
||||||
#include "cutlass_gemm_caller.cuh"
|
#include "cutlass_gemm_caller.cuh"
|
||||||
|
|
||||||
@@ -138,13 +140,15 @@ struct sm120_fp8_config_M16 {
|
|||||||
template <typename InType, typename OutType,
|
template <typename InType, typename OutType,
|
||||||
template <typename, typename, typename> typename Epilogue,
|
template <typename, typename, typename> typename Epilogue,
|
||||||
typename... EpilogueArgs>
|
typename... EpilogueArgs>
|
||||||
inline void cutlass_gemm_sm120_fp8_dispatch(torch::Tensor& out,
|
inline void cutlass_gemm_sm120_fp8_dispatch(torch::stable::Tensor& out,
|
||||||
torch::Tensor const& a,
|
torch::stable::Tensor const& a,
|
||||||
torch::Tensor const& b,
|
torch::stable::Tensor const& b,
|
||||||
EpilogueArgs&&... args) {
|
EpilogueArgs&&... args) {
|
||||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
STD_TORCH_CHECK(a.scalar_type() ==
|
||||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||||
|
STD_TORCH_CHECK(b.scalar_type() ==
|
||||||
|
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||||
|
|
||||||
int M = a.size(0);
|
int M = a.size(0);
|
||||||
|
|
||||||
@@ -177,19 +181,21 @@ inline void cutlass_gemm_sm120_fp8_dispatch(torch::Tensor& out,
|
|||||||
|
|
||||||
template <template <typename, typename, typename> typename Epilogue,
|
template <template <typename, typename, typename> typename Epilogue,
|
||||||
typename... EpilogueArgs>
|
typename... EpilogueArgs>
|
||||||
void cutlass_scaled_mm_sm120_fp8_epilogue(torch::Tensor& out,
|
void cutlass_scaled_mm_sm120_fp8_epilogue(torch::stable::Tensor& out,
|
||||||
torch::Tensor const& a,
|
torch::stable::Tensor const& a,
|
||||||
torch::Tensor const& b,
|
torch::stable::Tensor const& b,
|
||||||
EpilogueArgs&&... epilogue_args) {
|
EpilogueArgs&&... epilogue_args) {
|
||||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
STD_TORCH_CHECK(a.scalar_type() ==
|
||||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||||
|
STD_TORCH_CHECK(b.scalar_type() ==
|
||||||
|
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||||
|
|
||||||
if (out.dtype() == torch::kBFloat16) {
|
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
|
||||||
return cutlass_gemm_sm120_fp8_dispatch<cutlass::float_e4m3_t,
|
return cutlass_gemm_sm120_fp8_dispatch<cutlass::float_e4m3_t,
|
||||||
cutlass::bfloat16_t, Epilogue>(
|
cutlass::bfloat16_t, Epilogue>(
|
||||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||||
} else {
|
} else {
|
||||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
|
||||||
return cutlass_gemm_sm120_fp8_dispatch<cutlass::float_e4m3_t,
|
return cutlass_gemm_sm120_fp8_dispatch<cutlass::float_e4m3_t,
|
||||||
cutlass::half_t, Epilogue>(
|
cutlass::half_t, Epilogue>(
|
||||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||||
@@ -0,0 +1,24 @@
|
|||||||
|
#include "scaled_mm_kernels.hpp"
|
||||||
|
#include "scaled_mm_sm90_fp8_dispatch.cuh"
|
||||||
|
|
||||||
|
namespace vllm {
|
||||||
|
|
||||||
|
void cutlass_scaled_mm_sm90_fp8(
|
||||||
|
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||||
|
torch::stable::Tensor const& b_scales,
|
||||||
|
std::optional<torch::stable::Tensor> const& bias) {
|
||||||
|
STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||||
|
if (bias) {
|
||||||
|
STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
|
||||||
|
"currently bias dtype must match output dtype ",
|
||||||
|
out.scalar_type());
|
||||||
|
return cutlass_scaled_mm_sm90_fp8_epilogue<true>(out, a, b, a_scales,
|
||||||
|
b_scales, *bias);
|
||||||
|
} else {
|
||||||
|
return cutlass_scaled_mm_sm90_fp8_epilogue<false>(out, a, b, a_scales,
|
||||||
|
b_scales);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace vllm
|
||||||
@@ -1,5 +1,7 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <torch/headeronly/util/shim_utils.h>
|
||||||
|
|
||||||
#include "scaled_mm.cuh"
|
#include "scaled_mm.cuh"
|
||||||
#include "cutlass_gemm_caller.cuh"
|
#include "cutlass_gemm_caller.cuh"
|
||||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||||
@@ -235,8 +237,9 @@ struct sm90_fp8_config_M16_N8192 {
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <typename Gemm, typename... EpilogueArgs>
|
template <typename Gemm, typename... EpilogueArgs>
|
||||||
void cutlass_gemm_caller_sm90_fp8(torch::Tensor& out, torch::Tensor const& a,
|
void cutlass_gemm_caller_sm90_fp8(torch::stable::Tensor& out,
|
||||||
torch::Tensor const& b,
|
torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b,
|
||||||
EpilogueArgs&&... epilogue_params) {
|
EpilogueArgs&&... epilogue_params) {
|
||||||
static constexpr bool swap_ab = Gemm::swap_ab;
|
static constexpr bool swap_ab = Gemm::swap_ab;
|
||||||
using ElementAB = typename Gemm::ElementAB;
|
using ElementAB = typename Gemm::ElementAB;
|
||||||
@@ -280,15 +283,15 @@ void cutlass_gemm_caller_sm90_fp8(torch::Tensor& out, torch::Tensor const& a,
|
|||||||
|
|
||||||
template <typename InType, typename OutType, bool EnableBias,
|
template <typename InType, typename OutType, bool EnableBias,
|
||||||
typename... EpilogueArgs>
|
typename... EpilogueArgs>
|
||||||
inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out,
|
inline void cutlass_gemm_sm90_fp8_dispatch(
|
||||||
torch::Tensor const& a,
|
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||||
torch::Tensor const& b,
|
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||||
torch::Tensor const& a_scales,
|
torch::stable::Tensor const& b_scales, EpilogueArgs&&... args) {
|
||||||
torch::Tensor const& b_scales,
|
|
||||||
EpilogueArgs&&... args) {
|
|
||||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
STD_TORCH_CHECK(a.scalar_type() ==
|
||||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||||
|
STD_TORCH_CHECK(b.scalar_type() ==
|
||||||
|
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||||
|
|
||||||
using Cutlass3xGemmDefault =
|
using Cutlass3xGemmDefault =
|
||||||
typename sm90_fp8_config_default<InType, OutType,
|
typename sm90_fp8_config_default<InType, OutType,
|
||||||
@@ -347,22 +350,24 @@ inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out,
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <bool EnableBias, typename... EpilogueArgs>
|
template <bool EnableBias, typename... EpilogueArgs>
|
||||||
void cutlass_scaled_mm_sm90_fp8_epilogue(torch::Tensor& out,
|
void cutlass_scaled_mm_sm90_fp8_epilogue(torch::stable::Tensor& out,
|
||||||
torch::Tensor const& a,
|
torch::stable::Tensor const& a,
|
||||||
torch::Tensor const& b,
|
torch::stable::Tensor const& b,
|
||||||
torch::Tensor const& a_scales,
|
torch::stable::Tensor const& a_scales,
|
||||||
torch::Tensor const& b_scales,
|
torch::stable::Tensor const& b_scales,
|
||||||
EpilogueArgs&&... epilogue_args) {
|
EpilogueArgs&&... epilogue_args) {
|
||||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
STD_TORCH_CHECK(a.scalar_type() ==
|
||||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||||
|
STD_TORCH_CHECK(b.scalar_type() ==
|
||||||
|
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||||
|
|
||||||
if (out.dtype() == torch::kBFloat16) {
|
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
|
||||||
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
|
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
|
||||||
cutlass::bfloat16_t, EnableBias>(
|
cutlass::bfloat16_t, EnableBias>(
|
||||||
out, a, b, a_scales, b_scales,
|
out, a, b, a_scales, b_scales,
|
||||||
std::forward<EpilogueArgs>(epilogue_args)...);
|
std::forward<EpilogueArgs>(epilogue_args)...);
|
||||||
} else {
|
} else {
|
||||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
|
||||||
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
|
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
|
||||||
cutlass::half_t, EnableBias>(
|
cutlass::half_t, EnableBias>(
|
||||||
out, a, b, a_scales, b_scales,
|
out, a, b, a_scales, b_scales,
|
||||||
@@ -0,0 +1,25 @@
|
|||||||
|
#include "scaled_mm_kernels.hpp"
|
||||||
|
#include "scaled_mm_sm90_int8_dispatch.cuh"
|
||||||
|
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||||
|
|
||||||
|
namespace vllm {
|
||||||
|
|
||||||
|
void cutlass_scaled_mm_sm90_int8(
|
||||||
|
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||||
|
torch::stable::Tensor const& b_scales,
|
||||||
|
std::optional<torch::stable::Tensor> const& bias) {
|
||||||
|
STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||||
|
if (bias) {
|
||||||
|
STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
|
||||||
|
"currently bias dtype must match output dtype ",
|
||||||
|
out.scalar_type());
|
||||||
|
return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogueBias>(
|
||||||
|
out, a, b, a_scales, b_scales, *bias);
|
||||||
|
} else {
|
||||||
|
return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogue>(
|
||||||
|
out, a, b, a_scales, b_scales);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace vllm
|
||||||
@@ -1,5 +1,7 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <torch/headeronly/util/shim_utils.h>
|
||||||
|
|
||||||
#include "scaled_mm.cuh"
|
#include "scaled_mm.cuh"
|
||||||
#include "cutlass_gemm_caller.cuh"
|
#include "cutlass_gemm_caller.cuh"
|
||||||
|
|
||||||
@@ -87,13 +89,13 @@ struct sm90_int8_config_M32_NSmall {
|
|||||||
template <typename InType, typename OutType,
|
template <typename InType, typename OutType,
|
||||||
template <typename, typename, typename> typename Epilogue,
|
template <typename, typename, typename> typename Epilogue,
|
||||||
typename... EpilogueArgs>
|
typename... EpilogueArgs>
|
||||||
inline void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out,
|
inline void cutlass_gemm_sm90_int8_dispatch(torch::stable::Tensor& out,
|
||||||
torch::Tensor const& a,
|
torch::stable::Tensor const& a,
|
||||||
torch::Tensor const& b,
|
torch::stable::Tensor const& b,
|
||||||
EpilogueArgs&&... args) {
|
EpilogueArgs&&... args) {
|
||||||
static_assert(std::is_same<InType, int8_t>());
|
static_assert(std::is_same<InType, int8_t>());
|
||||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||||
|
|
||||||
using Cutlass3xGemmDefault =
|
using Cutlass3xGemmDefault =
|
||||||
typename sm90_int8_config_default<InType, OutType,
|
typename sm90_int8_config_default<InType, OutType,
|
||||||
@@ -142,19 +144,19 @@ inline void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out,
|
|||||||
|
|
||||||
template <template <typename, typename, typename> typename Epilogue,
|
template <template <typename, typename, typename> typename Epilogue,
|
||||||
typename... EpilogueArgs>
|
typename... EpilogueArgs>
|
||||||
void cutlass_scaled_mm_sm90_int8_epilogue(torch::Tensor& out,
|
void cutlass_scaled_mm_sm90_int8_epilogue(torch::stable::Tensor& out,
|
||||||
torch::Tensor const& a,
|
torch::stable::Tensor const& a,
|
||||||
torch::Tensor const& b,
|
torch::stable::Tensor const& b,
|
||||||
EpilogueArgs&&... epilogue_args) {
|
EpilogueArgs&&... epilogue_args) {
|
||||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||||
|
|
||||||
if (out.dtype() == torch::kBFloat16) {
|
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
|
||||||
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
|
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
|
||||||
Epilogue>(
|
Epilogue>(
|
||||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||||
} else {
|
} else {
|
||||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
|
||||||
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
|
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
|
||||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||||
}
|
}
|
||||||
@@ -1,10 +1,10 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <cuda.h>
|
#include <cuda.h>
|
||||||
#include <torch/all.h>
|
#include <torch/csrc/stable/tensor.h>
|
||||||
#include <c10/cuda/CUDAStream.h>
|
#include <torch/headeronly/core/ScalarType.h>
|
||||||
|
#include "libtorch_stable/torch_utils.h"
|
||||||
|
|
||||||
#include "core/scalar_type.hpp"
|
|
||||||
#include "cutlass/bfloat16.h"
|
#include "cutlass/bfloat16.h"
|
||||||
#include "cutlass/float8.h"
|
#include "cutlass/float8.h"
|
||||||
|
|
||||||
@@ -31,7 +31,7 @@ __global__ void get_group_gemm_starts(
|
|||||||
}
|
}
|
||||||
|
|
||||||
#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE) \
|
#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE) \
|
||||||
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
|
else if (out_tensors.scalar_type() == TENSOR_C_TYPE) { \
|
||||||
get_group_gemm_starts<cutlass::float_e4m3_t, C_TYPE, float> \
|
get_group_gemm_starts<cutlass::float_e4m3_t, C_TYPE, float> \
|
||||||
<<<1, num_experts, 0, stream>>>( \
|
<<<1, num_experts, 0, stream>>>( \
|
||||||
static_cast<int64_t*>(expert_offsets.data_ptr()), \
|
static_cast<int64_t*>(expert_offsets.data_ptr()), \
|
||||||
@@ -51,32 +51,39 @@ __global__ void get_group_gemm_starts(
|
|||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
void run_get_group_gemm_starts(
|
void run_get_group_gemm_starts(
|
||||||
torch::Tensor const& expert_offsets, torch::Tensor& a_ptrs,
|
torch::stable::Tensor const& expert_offsets, torch::stable::Tensor& a_ptrs,
|
||||||
torch::Tensor& b_ptrs, torch::Tensor& out_ptrs,
|
torch::stable::Tensor& b_ptrs, torch::stable::Tensor& out_ptrs,
|
||||||
torch::Tensor& a_scales_ptrs, torch::Tensor& b_scales_ptrs,
|
torch::stable::Tensor& a_scales_ptrs, torch::stable::Tensor& b_scales_ptrs,
|
||||||
torch::Tensor const& a_tensors, torch::Tensor const& b_tensors,
|
torch::stable::Tensor const& a_tensors,
|
||||||
torch::Tensor& out_tensors, torch::Tensor const& a_scales,
|
torch::stable::Tensor const& b_tensors, torch::stable::Tensor& out_tensors,
|
||||||
torch::Tensor const& b_scales) {
|
torch::stable::Tensor const& a_scales,
|
||||||
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
|
torch::stable::Tensor const& b_scales) {
|
||||||
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn);
|
STD_TORCH_CHECK(a_tensors.scalar_type() ==
|
||||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
STD_TORCH_CHECK(b_tensors.scalar_type() ==
|
||||||
|
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||||
|
STD_TORCH_CHECK(a_scales.scalar_type() ==
|
||||||
|
torch::headeronly::ScalarType::Float);
|
||||||
|
STD_TORCH_CHECK(b_scales.scalar_type() ==
|
||||||
|
torch::headeronly::ScalarType::Float);
|
||||||
// expect int64_t to avoid overflow during offset calculations
|
// expect int64_t to avoid overflow during offset calculations
|
||||||
TORCH_CHECK(expert_offsets.dtype() == torch::kInt64);
|
STD_TORCH_CHECK(expert_offsets.scalar_type() ==
|
||||||
|
torch::headeronly::ScalarType::Long);
|
||||||
|
|
||||||
int num_experts = static_cast<int>(expert_offsets.size(0));
|
int num_experts = static_cast<int>(expert_offsets.size(0));
|
||||||
bool per_act_token = a_scales.numel() != 1;
|
bool per_act_token = a_scales.numel() != 1;
|
||||||
bool per_out_ch = b_scales.numel() != num_experts;
|
bool per_out_ch = b_scales.numel() != num_experts;
|
||||||
|
|
||||||
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
|
auto stream = get_current_cuda_stream(a_tensors.get_device_index());
|
||||||
|
|
||||||
if (false) {
|
if (false) {
|
||||||
}
|
}
|
||||||
__CALL_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t)
|
__CALL_GET_STARTS_KERNEL(torch::headeronly::ScalarType::BFloat16,
|
||||||
__CALL_GET_STARTS_KERNEL(torch::kFloat16, half)
|
cutlass::bfloat16_t)
|
||||||
|
__CALL_GET_STARTS_KERNEL(torch::headeronly::ScalarType::Half, half)
|
||||||
else {
|
else {
|
||||||
TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
|
STD_TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
@@ -6,6 +6,7 @@
|
|||||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||||
|
|
||||||
|
#include <torch/csrc/stable/ops.h>
|
||||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||||
#include "cutlass_extensions/common.hpp"
|
#include "cutlass_extensions/common.hpp"
|
||||||
#include "get_group_starts.cuh"
|
#include "get_group_starts.cuh"
|
||||||
@@ -84,13 +85,17 @@ struct cutlass_3x_group_gemm {
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <typename Gemm>
|
template <typename Gemm>
|
||||||
void cutlass_group_gemm_caller(
|
void cutlass_group_gemm_caller(torch::stable::Tensor& out_tensors,
|
||||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
torch::stable::Tensor const& a_tensors,
|
||||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
torch::stable::Tensor const& b_tensors,
|
||||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
torch::stable::Tensor const& a_scales,
|
||||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
torch::stable::Tensor const& b_scales,
|
||||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
|
torch::stable::Tensor const& expert_offsets,
|
||||||
bool per_act_token, bool per_out_ch) {
|
torch::stable::Tensor const& problem_sizes,
|
||||||
|
torch::stable::Tensor const& a_strides,
|
||||||
|
torch::stable::Tensor const& b_strides,
|
||||||
|
torch::stable::Tensor const& c_strides,
|
||||||
|
bool per_act_token, bool per_out_ch) {
|
||||||
static constexpr bool swap_ab = Gemm::swap_ab;
|
static constexpr bool swap_ab = Gemm::swap_ab;
|
||||||
|
|
||||||
using ElementAB = typename Gemm::ElementAB;
|
using ElementAB = typename Gemm::ElementAB;
|
||||||
@@ -98,16 +103,20 @@ void cutlass_group_gemm_caller(
|
|||||||
|
|
||||||
int num_experts = static_cast<int>(expert_offsets.size(0));
|
int num_experts = static_cast<int>(expert_offsets.size(0));
|
||||||
|
|
||||||
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
|
auto stream = get_current_cuda_stream(a_tensors.get_device_index());
|
||||||
|
|
||||||
auto options_int =
|
auto device = a_tensors.device();
|
||||||
torch::TensorOptions().dtype(torch::kInt64).device(a_tensors.device());
|
|
||||||
|
|
||||||
torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
|
torch::stable::Tensor a_ptrs = torch::stable::empty(
|
||||||
torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
|
{num_experts}, torch::headeronly::ScalarType::Long, std::nullopt, device);
|
||||||
torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
|
torch::stable::Tensor b_ptrs = torch::stable::empty(
|
||||||
torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
|
{num_experts}, torch::headeronly::ScalarType::Long, std::nullopt, device);
|
||||||
torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
|
torch::stable::Tensor out_ptrs = torch::stable::empty(
|
||||||
|
{num_experts}, torch::headeronly::ScalarType::Long, std::nullopt, device);
|
||||||
|
torch::stable::Tensor a_scales_ptrs = torch::stable::empty(
|
||||||
|
{num_experts}, torch::headeronly::ScalarType::Long, std::nullopt, device);
|
||||||
|
torch::stable::Tensor b_scales_ptrs = torch::stable::empty(
|
||||||
|
{num_experts}, torch::headeronly::ScalarType::Long, std::nullopt, device);
|
||||||
|
|
||||||
run_get_group_gemm_starts(expert_offsets, a_ptrs, b_ptrs, out_ptrs,
|
run_get_group_gemm_starts(expert_offsets, a_ptrs, b_ptrs, out_ptrs,
|
||||||
a_scales_ptrs, b_scales_ptrs, a_tensors, b_tensors,
|
a_scales_ptrs, b_scales_ptrs, a_tensors, b_tensors,
|
||||||
@@ -156,7 +165,7 @@ void cutlass_group_gemm_caller(
|
|||||||
static_cast<ElementD**>(out_ptrs.data_ptr()),
|
static_cast<ElementD**>(out_ptrs.data_ptr()),
|
||||||
static_cast<StrideC*>(c_strides.data_ptr())};
|
static_cast<StrideC*>(c_strides.data_ptr())};
|
||||||
|
|
||||||
int device_id = a_tensors.device().index();
|
int device_id = a_tensors.get_device_index();
|
||||||
static const cutlass::KernelHardwareInfo hw_info{
|
static const cutlass::KernelHardwareInfo hw_info{
|
||||||
device_id, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
|
device_id, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
|
||||||
device_id)};
|
device_id)};
|
||||||
@@ -170,9 +179,9 @@ void cutlass_group_gemm_caller(
|
|||||||
CUTLASS_CHECK(gemm_op.can_implement(args));
|
CUTLASS_CHECK(gemm_op.can_implement(args));
|
||||||
|
|
||||||
size_t workspace_size = gemm_op.get_workspace_size(args);
|
size_t workspace_size = gemm_op.get_workspace_size(args);
|
||||||
auto const workspace_options =
|
auto workspace =
|
||||||
torch::TensorOptions().dtype(torch::kUInt8).device(a_tensors.device());
|
torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
|
||||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
std::nullopt, device);
|
||||||
|
|
||||||
cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
|
cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
|
||||||
CUTLASS_CHECK(status);
|
CUTLASS_CHECK(status);
|
||||||
@@ -1,7 +1,8 @@
|
|||||||
#include <cudaTypedefs.h>
|
#include <cudaTypedefs.h>
|
||||||
|
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
#include "libtorch_stable/torch_utils.h"
|
||||||
#include <torch/all.h>
|
#include <torch/csrc/stable/tensor.h>
|
||||||
|
#include <torch/headeronly/core/ScalarType.h>
|
||||||
|
|
||||||
#include "cutlass/cutlass.h"
|
#include "cutlass/cutlass.h"
|
||||||
#include "grouped_mm_c3x.cuh"
|
#include "grouped_mm_c3x.cuh"
|
||||||
@@ -62,21 +63,27 @@ struct sm100_fp8_config_N8192 {
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <typename InType, typename OutType>
|
template <typename InType, typename OutType>
|
||||||
void run_cutlass_moe_mm_sm100(
|
void run_cutlass_moe_mm_sm100(torch::stable::Tensor& out_tensors,
|
||||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
torch::stable::Tensor const& a_tensors,
|
||||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
torch::stable::Tensor const& b_tensors,
|
||||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
torch::stable::Tensor const& a_scales,
|
||||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
torch::stable::Tensor const& b_scales,
|
||||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
|
torch::stable::Tensor const& expert_offsets,
|
||||||
bool per_act_token, bool per_out_ch) {
|
torch::stable::Tensor const& problem_sizes,
|
||||||
TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
|
torch::stable::Tensor const& a_strides,
|
||||||
TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
|
torch::stable::Tensor const& b_strides,
|
||||||
TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");
|
torch::stable::Tensor const& c_strides,
|
||||||
|
bool per_act_token, bool per_out_ch) {
|
||||||
|
STD_TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
|
||||||
|
STD_TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
|
||||||
|
STD_TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");
|
||||||
|
|
||||||
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn,
|
STD_TORCH_CHECK(
|
||||||
"A tensors must be of type float8_e4m3fn.");
|
a_tensors.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn,
|
||||||
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn,
|
"A tensors must be of type float8_e4m3fn.");
|
||||||
"B tensors must be of type float8_e4m3fn.");
|
STD_TORCH_CHECK(
|
||||||
|
b_tensors.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn,
|
||||||
|
"B tensors must be of type float8_e4m3fn.");
|
||||||
|
|
||||||
using Cutlass3xGemmDefault = typename sm100_fp8_config_default<
|
using Cutlass3xGemmDefault = typename sm100_fp8_config_default<
|
||||||
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
|
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
|
||||||
@@ -107,14 +114,18 @@ void run_cutlass_moe_mm_sm100(
|
|||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void dispatch_moe_mm_sm100(
|
void dispatch_moe_mm_sm100(torch::stable::Tensor& out_tensors,
|
||||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
torch::stable::Tensor const& a_tensors,
|
||||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
torch::stable::Tensor const& b_tensors,
|
||||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
torch::stable::Tensor const& a_scales,
|
||||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
torch::stable::Tensor const& b_scales,
|
||||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
|
torch::stable::Tensor const& expert_offsets,
|
||||||
bool per_act_token, bool per_out_ch) {
|
torch::stable::Tensor const& problem_sizes,
|
||||||
if (out_tensors.dtype() == torch::kBFloat16) {
|
torch::stable::Tensor const& a_strides,
|
||||||
|
torch::stable::Tensor const& b_strides,
|
||||||
|
torch::stable::Tensor const& c_strides,
|
||||||
|
bool per_act_token, bool per_out_ch) {
|
||||||
|
if (out_tensors.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
|
||||||
run_cutlass_moe_mm_sm100<cutlass::float_e4m3_t, cutlass::bfloat16_t>(
|
run_cutlass_moe_mm_sm100<cutlass::float_e4m3_t, cutlass::bfloat16_t>(
|
||||||
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
|
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
|
||||||
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
|
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
|
||||||
@@ -127,13 +138,17 @@ void dispatch_moe_mm_sm100(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void cutlass_moe_mm_sm100(
|
void cutlass_moe_mm_sm100(torch::stable::Tensor& out_tensors,
|
||||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
torch::stable::Tensor const& a_tensors,
|
||||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
torch::stable::Tensor const& b_tensors,
|
||||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
torch::stable::Tensor const& a_scales,
|
||||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
torch::stable::Tensor const& b_scales,
|
||||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
|
torch::stable::Tensor const& expert_offsets,
|
||||||
bool per_act_token, bool per_out_ch) {
|
torch::stable::Tensor const& problem_sizes,
|
||||||
|
torch::stable::Tensor const& a_strides,
|
||||||
|
torch::stable::Tensor const& b_strides,
|
||||||
|
torch::stable::Tensor const& c_strides,
|
||||||
|
bool per_act_token, bool per_out_ch) {
|
||||||
dispatch_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
|
dispatch_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
|
||||||
expert_offsets, problem_sizes, a_strides, b_strides,
|
expert_offsets, problem_sizes, a_strides, b_strides,
|
||||||
c_strides, per_act_token, per_out_ch);
|
c_strides, per_act_token, per_out_ch);
|
||||||
@@ -1,7 +1,8 @@
|
|||||||
#include <cudaTypedefs.h>
|
#include <cudaTypedefs.h>
|
||||||
|
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
#include "libtorch_stable/torch_utils.h"
|
||||||
#include <torch/all.h>
|
#include <torch/csrc/stable/tensor.h>
|
||||||
|
#include <torch/headeronly/core/ScalarType.h>
|
||||||
|
|
||||||
#include "cutlass/cutlass.h"
|
#include "cutlass/cutlass.h"
|
||||||
#include "grouped_mm_c3x.cuh"
|
#include "grouped_mm_c3x.cuh"
|
||||||
@@ -103,21 +104,27 @@ struct sm90_fp8_config_N8192 {
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <typename InType, typename OutType>
|
template <typename InType, typename OutType>
|
||||||
void run_cutlass_moe_mm_sm90(
|
void run_cutlass_moe_mm_sm90(torch::stable::Tensor& out_tensors,
|
||||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
torch::stable::Tensor const& a_tensors,
|
||||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
torch::stable::Tensor const& b_tensors,
|
||||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
torch::stable::Tensor const& a_scales,
|
||||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
torch::stable::Tensor const& b_scales,
|
||||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
|
torch::stable::Tensor const& expert_offsets,
|
||||||
bool per_act_token, bool per_out_ch) {
|
torch::stable::Tensor const& problem_sizes,
|
||||||
TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
|
torch::stable::Tensor const& a_strides,
|
||||||
TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
|
torch::stable::Tensor const& b_strides,
|
||||||
TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");
|
torch::stable::Tensor const& c_strides,
|
||||||
|
bool per_act_token, bool per_out_ch) {
|
||||||
|
STD_TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
|
||||||
|
STD_TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
|
||||||
|
STD_TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");
|
||||||
|
|
||||||
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn,
|
STD_TORCH_CHECK(
|
||||||
"A tensors must be of type float8_e4m3fn.");
|
a_tensors.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn,
|
||||||
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn,
|
"A tensors must be of type float8_e4m3fn.");
|
||||||
"B tensors must be of type float8_e4m3fn.");
|
STD_TORCH_CHECK(
|
||||||
|
b_tensors.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn,
|
||||||
|
"B tensors must be of type float8_e4m3fn.");
|
||||||
|
|
||||||
using Cutlass3xGemmN8192 = typename sm90_fp8_config_N8192<
|
using Cutlass3xGemmN8192 = typename sm90_fp8_config_N8192<
|
||||||
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
|
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
|
||||||
@@ -163,14 +170,18 @@ void run_cutlass_moe_mm_sm90(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void dispatch_moe_mm_sm90(
|
void dispatch_moe_mm_sm90(torch::stable::Tensor& out_tensors,
|
||||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
torch::stable::Tensor const& a_tensors,
|
||||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
torch::stable::Tensor const& b_tensors,
|
||||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
torch::stable::Tensor const& a_scales,
|
||||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
torch::stable::Tensor const& b_scales,
|
||||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
|
torch::stable::Tensor const& expert_offsets,
|
||||||
bool per_act_token, bool per_out_ch) {
|
torch::stable::Tensor const& problem_sizes,
|
||||||
if (out_tensors.dtype() == torch::kBFloat16) {
|
torch::stable::Tensor const& a_strides,
|
||||||
|
torch::stable::Tensor const& b_strides,
|
||||||
|
torch::stable::Tensor const& c_strides,
|
||||||
|
bool per_act_token, bool per_out_ch) {
|
||||||
|
if (out_tensors.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
|
||||||
run_cutlass_moe_mm_sm90<cutlass::float_e4m3_t, cutlass::bfloat16_t>(
|
run_cutlass_moe_mm_sm90<cutlass::float_e4m3_t, cutlass::bfloat16_t>(
|
||||||
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
|
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
|
||||||
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
|
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
|
||||||
@@ -185,13 +196,17 @@ void dispatch_moe_mm_sm90(
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void cutlass_moe_mm_sm90(
|
void cutlass_moe_mm_sm90(torch::stable::Tensor& out_tensors,
|
||||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
torch::stable::Tensor const& a_tensors,
|
||||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
torch::stable::Tensor const& b_tensors,
|
||||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
torch::stable::Tensor const& a_scales,
|
||||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
torch::stable::Tensor const& b_scales,
|
||||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
|
torch::stable::Tensor const& expert_offsets,
|
||||||
bool per_act_token, bool per_out_ch) {
|
torch::stable::Tensor const& problem_sizes,
|
||||||
|
torch::stable::Tensor const& a_strides,
|
||||||
|
torch::stable::Tensor const& b_strides,
|
||||||
|
torch::stable::Tensor const& c_strides,
|
||||||
|
bool per_act_token, bool per_out_ch) {
|
||||||
dispatch_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
|
dispatch_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
|
||||||
expert_offsets, problem_sizes, a_strides, b_strides,
|
expert_offsets, problem_sizes, a_strides, b_strides,
|
||||||
c_strides, per_act_token, per_out_ch);
|
c_strides, per_act_token, per_out_ch);
|
||||||
@@ -1,9 +1,11 @@
|
|||||||
#include <cudaTypedefs.h>
|
#include <cudaTypedefs.h>
|
||||||
|
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
#include "libtorch_stable/torch_utils.h"
|
||||||
#include <torch/all.h>
|
#include <torch/csrc/stable/tensor.h>
|
||||||
|
#include <torch/csrc/stable/ops.h>
|
||||||
|
#include <torch/headeronly/core/ScalarType.h>
|
||||||
|
|
||||||
#include "dispatch_utils.h"
|
#include "libtorch_stable/dispatch_utils.h"
|
||||||
|
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
|
||||||
@@ -110,19 +112,22 @@ __global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids,
|
|||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
inline void launch_compute_problem_sizes(
|
inline void launch_compute_problem_sizes(const torch::stable::Tensor& topk_ids,
|
||||||
const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
|
torch::stable::Tensor& problem_sizes1,
|
||||||
torch::Tensor& problem_sizes2, torch::Tensor& atomic_buffer,
|
torch::stable::Tensor& problem_sizes2,
|
||||||
int64_t num_experts, int64_t n, int64_t k, cudaStream_t stream,
|
torch::stable::Tensor& atomic_buffer,
|
||||||
const bool swap_ab, const bool is_gated) {
|
int64_t num_experts, int64_t n,
|
||||||
|
int64_t k, cudaStream_t stream,
|
||||||
|
const bool swap_ab,
|
||||||
|
const bool is_gated) {
|
||||||
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
|
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
|
||||||
|
|
||||||
auto const* topk_ptr = topk_ids.data_ptr<int32_t>();
|
auto const* topk_ptr = topk_ids.const_data_ptr<int32_t>();
|
||||||
auto* ps1_ptr = problem_sizes1.data_ptr<int32_t>();
|
auto* ps1_ptr = problem_sizes1.mutable_data_ptr<int32_t>();
|
||||||
auto* ps2_ptr = problem_sizes2.data_ptr<int32_t>();
|
auto* ps2_ptr = problem_sizes2.mutable_data_ptr<int32_t>();
|
||||||
auto* atomic_ptr = atomic_buffer.data_ptr<int32_t>();
|
auto* atomic_ptr = atomic_buffer.mutable_data_ptr<int32_t>();
|
||||||
|
|
||||||
VLLM_DISPATCH_BOOL(swap_ab, SwapAB, [&] {
|
VLLM_STABLE_DISPATCH_BOOL(swap_ab, SwapAB, [&] {
|
||||||
compute_problem_sizes<SwapAB><<<num_experts, num_threads, 0, stream>>>(
|
compute_problem_sizes<SwapAB><<<num_experts, num_threads, 0, stream>>>(
|
||||||
topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr,
|
topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr,
|
||||||
static_cast<int>(topk_ids.numel()), static_cast<int>(n),
|
static_cast<int>(topk_ids.numel()), static_cast<int>(n),
|
||||||
@@ -171,46 +176,53 @@ __global__ void compute_problem_sizes_from_expert_offsets(
|
|||||||
}
|
}
|
||||||
|
|
||||||
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
|
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
|
||||||
const torch::Tensor& expert_first_token_offset,
|
const torch::stable::Tensor& expert_first_token_offset,
|
||||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
torch::stable::Tensor& problem_sizes1,
|
||||||
const int64_t n, const int64_t k, const bool swap_ab) {
|
torch::stable::Tensor& problem_sizes2, const int64_t n, const int64_t k,
|
||||||
TORCH_CHECK(expert_first_token_offset.is_cuda(),
|
const bool swap_ab) {
|
||||||
"expert_first_token_offset must be a CUDA tensor");
|
STD_TORCH_CHECK(expert_first_token_offset.is_cuda(),
|
||||||
TORCH_CHECK(expert_first_token_offset.dtype() == torch::kInt64,
|
"expert_first_token_offset must be a CUDA tensor");
|
||||||
"expert_first_token_offset must be int64");
|
STD_TORCH_CHECK(expert_first_token_offset.scalar_type() ==
|
||||||
|
torch::headeronly::ScalarType::Long,
|
||||||
|
"expert_first_token_offset must be int64");
|
||||||
|
|
||||||
TORCH_CHECK(problem_sizes1.is_cuda() && problem_sizes2.is_cuda(),
|
STD_TORCH_CHECK(problem_sizes1.is_cuda() && problem_sizes2.is_cuda(),
|
||||||
"problem_sizes must be CUDA tensors");
|
"problem_sizes must be CUDA tensors");
|
||||||
TORCH_CHECK(problem_sizes1.dtype() == torch::kInt32 &&
|
STD_TORCH_CHECK(
|
||||||
problem_sizes2.dtype() == torch::kInt32,
|
problem_sizes1.scalar_type() == torch::headeronly::ScalarType::Int &&
|
||||||
"problem_sizes must be int32");
|
problem_sizes2.scalar_type() == torch::headeronly::ScalarType::Int,
|
||||||
TORCH_CHECK(problem_sizes1.is_contiguous() && problem_sizes2.is_contiguous(),
|
"problem_sizes must be int32");
|
||||||
"problem_sizes must be contiguous");
|
STD_TORCH_CHECK(
|
||||||
TORCH_CHECK(problem_sizes1.dim() == 2 && problem_sizes2.dim() == 2,
|
problem_sizes1.is_contiguous() && problem_sizes2.is_contiguous(),
|
||||||
"problem_sizes must be 2D tensors");
|
"problem_sizes must be contiguous");
|
||||||
TORCH_CHECK(problem_sizes1.size(1) == 3 && problem_sizes2.size(1) == 3,
|
STD_TORCH_CHECK(problem_sizes1.dim() == 2 && problem_sizes2.dim() == 2,
|
||||||
"problem_sizes second dim must be 3");
|
"problem_sizes must be 2D tensors");
|
||||||
TORCH_CHECK(problem_sizes1.sizes() == problem_sizes2.sizes(),
|
STD_TORCH_CHECK(problem_sizes1.size(1) == 3 && problem_sizes2.size(1) == 3,
|
||||||
"problem_sizes1 and problem_sizes2 must have same shape");
|
"problem_sizes second dim must be 3");
|
||||||
|
STD_TORCH_CHECK(problem_sizes1.size(0) == problem_sizes2.size(0) &&
|
||||||
|
problem_sizes1.size(1) == problem_sizes2.size(1),
|
||||||
|
"problem_sizes1 and problem_sizes2 must have same shape");
|
||||||
|
|
||||||
int64_t const num_experts64 = problem_sizes1.size(0);
|
int64_t const num_experts64 = problem_sizes1.size(0);
|
||||||
TORCH_CHECK(expert_first_token_offset.numel() == num_experts64 + 1,
|
STD_TORCH_CHECK(
|
||||||
"expert_first_token_offset must have num_experts + 1 elements");
|
expert_first_token_offset.numel() == num_experts64 + 1,
|
||||||
TORCH_CHECK(num_experts64 <= INT32_MAX, "num_experts must fit in int32");
|
"expert_first_token_offset must have num_experts + 1 elements");
|
||||||
TORCH_CHECK(n <= INT32_MAX && k <= INT32_MAX, "n and k must fit in int32");
|
STD_TORCH_CHECK(num_experts64 <= INT32_MAX, "num_experts must fit in int32");
|
||||||
|
STD_TORCH_CHECK(n <= INT32_MAX && k <= INT32_MAX,
|
||||||
|
"n and k must fit in int32");
|
||||||
|
|
||||||
int const num_experts = static_cast<int>(num_experts64);
|
int const num_experts = static_cast<int>(num_experts64);
|
||||||
auto stream = at::cuda::getCurrentCUDAStream(
|
auto stream =
|
||||||
expert_first_token_offset.device().index());
|
get_current_cuda_stream(expert_first_token_offset.get_device_index());
|
||||||
|
|
||||||
int const threads = (num_experts < 256) ? num_experts : 256;
|
int const threads = (num_experts < 256) ? num_experts : 256;
|
||||||
int const blocks = (num_experts + threads - 1) / threads;
|
int const blocks = (num_experts + threads - 1) / threads;
|
||||||
|
|
||||||
auto const* offsets_ptr = expert_first_token_offset.data_ptr<int64_t>();
|
auto const* offsets_ptr = expert_first_token_offset.const_data_ptr<int64_t>();
|
||||||
auto* ps1_ptr = problem_sizes1.data_ptr<int32_t>();
|
auto* ps1_ptr = problem_sizes1.mutable_data_ptr<int32_t>();
|
||||||
auto* ps2_ptr = problem_sizes2.data_ptr<int32_t>();
|
auto* ps2_ptr = problem_sizes2.mutable_data_ptr<int32_t>();
|
||||||
|
|
||||||
VLLM_DISPATCH_BOOL(swap_ab, SwapAB, [&] {
|
VLLM_STABLE_DISPATCH_BOOL(swap_ab, SwapAB, [&] {
|
||||||
compute_problem_sizes_from_expert_offsets<SwapAB>
|
compute_problem_sizes_from_expert_offsets<SwapAB>
|
||||||
<<<blocks, threads, 0, stream>>>(offsets_ptr, ps1_ptr, ps2_ptr,
|
<<<blocks, threads, 0, stream>>>(offsets_ptr, ps1_ptr, ps2_ptr,
|
||||||
num_experts, static_cast<int>(n),
|
num_experts, static_cast<int>(n),
|
||||||
@@ -219,16 +231,19 @@ void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
|
|||||||
}
|
}
|
||||||
|
|
||||||
void get_cutlass_moe_mm_data_caller(
|
void get_cutlass_moe_mm_data_caller(
|
||||||
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
|
const torch::stable::Tensor& topk_ids,
|
||||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
torch::stable::Tensor& expert_offsets,
|
||||||
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
|
torch::stable::Tensor& problem_sizes1,
|
||||||
const int64_t num_experts, const int64_t n, const int64_t k,
|
torch::stable::Tensor& problem_sizes2,
|
||||||
const std::optional<torch::Tensor>& blockscale_offsets,
|
torch::stable::Tensor& input_permutation,
|
||||||
|
torch::stable::Tensor& output_permutation, const int64_t num_experts,
|
||||||
|
const int64_t n, const int64_t k,
|
||||||
|
const std::optional<torch::stable::Tensor>& blockscale_offsets,
|
||||||
const bool is_gated) {
|
const bool is_gated) {
|
||||||
auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index());
|
auto device = topk_ids.device();
|
||||||
auto options_int32 =
|
auto stream = get_current_cuda_stream(device.index());
|
||||||
torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
|
torch::stable::Tensor atomic_buffer = torch::stable::new_zeros(
|
||||||
torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32);
|
topk_ids, {num_experts}, torch::headeronly::ScalarType::Int);
|
||||||
|
|
||||||
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
|
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
|
||||||
|
|
||||||
@@ -290,11 +305,13 @@ __global__ void compute_batched_moe_data(
|
|||||||
}
|
}
|
||||||
|
|
||||||
void get_cutlass_batched_moe_mm_data_caller(
|
void get_cutlass_batched_moe_mm_data_caller(
|
||||||
torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1,
|
torch::stable::Tensor& expert_offsets,
|
||||||
torch::Tensor& problem_sizes2, const torch::Tensor& expert_num_tokens,
|
torch::stable::Tensor& problem_sizes1,
|
||||||
|
torch::stable::Tensor& problem_sizes2,
|
||||||
|
const torch::stable::Tensor& expert_num_tokens,
|
||||||
const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
|
const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
|
||||||
const int64_t k) {
|
const int64_t k) {
|
||||||
auto stream = at::cuda::getCurrentCUDAStream(expert_offsets.device().index());
|
auto stream = get_current_cuda_stream(expert_offsets.get_device_index());
|
||||||
|
|
||||||
if (num_local_experts * padded_m > SWAP_AB_THRESHOLD) {
|
if (num_local_experts * padded_m > SWAP_AB_THRESHOLD) {
|
||||||
compute_batched_moe_data<false><<<1, num_local_experts, 0, stream>>>(
|
compute_batched_moe_data<false><<<1, num_local_experts, 0, stream>>>(
|
||||||
@@ -311,4 +328,4 @@ void get_cutlass_batched_moe_mm_data_caller(
|
|||||||
static_cast<const int32_t*>(expert_num_tokens.data_ptr()), padded_m, n,
|
static_cast<const int32_t*>(expert_num_tokens.data_ptr()), padded_m, n,
|
||||||
k);
|
k);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
220
csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c2x.cu
Normal file
220
csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c2x.cu
Normal file
@@ -0,0 +1,220 @@
|
|||||||
|
#include <stddef.h>
|
||||||
|
#include <torch/csrc/stable/tensor.h>
|
||||||
|
#include <torch/headeronly/core/ScalarType.h>
|
||||||
|
#include "cutlass/cutlass.h"
|
||||||
|
|
||||||
|
#include "scaled_mm_c2x.cuh"
|
||||||
|
#include "scaled_mm_c2x_sm75_dispatch.cuh"
|
||||||
|
#include "scaled_mm_c2x_sm80_dispatch.cuh"
|
||||||
|
#include "scaled_mm_c2x_sm89_fp8_dispatch.cuh"
|
||||||
|
#include "scaled_mm_c2x_sm89_int8_dispatch.cuh"
|
||||||
|
|
||||||
|
#include "libtorch_stable/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp"
|
||||||
|
|
||||||
|
using namespace vllm;
|
||||||
|
|
||||||
|
/*
|
||||||
|
This file defines quantized GEMM operations using the CUTLASS 2.x API, for
|
||||||
|
NVIDIA GPUs with SM versions prior to sm90 (Hopper).
|
||||||
|
*/
|
||||||
|
|
||||||
|
template <template <typename, typename> typename Epilogue,
|
||||||
|
typename... EpilogueArgs>
|
||||||
|
void cutlass_scaled_mm_sm75_epilogue(torch::stable::Tensor& out,
|
||||||
|
torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b,
|
||||||
|
EpilogueArgs&&... epilogue_args) {
|
||||||
|
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||||
|
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||||
|
|
||||||
|
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
|
||||||
|
return cutlass_gemm_sm75_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
|
||||||
|
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||||
|
} else {
|
||||||
|
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
|
||||||
|
return cutlass_gemm_sm75_dispatch<int8_t, cutlass::half_t, Epilogue>(
|
||||||
|
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void cutlass_scaled_mm_sm75(torch::stable::Tensor& out,
|
||||||
|
torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b,
|
||||||
|
torch::stable::Tensor const& a_scales,
|
||||||
|
torch::stable::Tensor const& b_scales,
|
||||||
|
std::optional<torch::stable::Tensor> const& bias) {
|
||||||
|
STD_TORCH_CHECK(a_scales.scalar_type() ==
|
||||||
|
torch::headeronly::ScalarType::Float);
|
||||||
|
STD_TORCH_CHECK(b_scales.scalar_type() ==
|
||||||
|
torch::headeronly::ScalarType::Float);
|
||||||
|
if (bias) {
|
||||||
|
STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
|
||||||
|
"currently bias dtype must match output dtype ",
|
||||||
|
out.scalar_type());
|
||||||
|
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBias>(
|
||||||
|
out, a, b, a_scales, b_scales, *bias);
|
||||||
|
} else {
|
||||||
|
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogue>(
|
||||||
|
out, a, b, a_scales, b_scales);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void cutlass_scaled_mm_azp_sm75(
|
||||||
|
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||||
|
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
|
||||||
|
std::optional<torch::stable::Tensor> const& azp,
|
||||||
|
std::optional<torch::stable::Tensor> const& bias) {
|
||||||
|
STD_TORCH_CHECK(a_scales.scalar_type() ==
|
||||||
|
torch::headeronly::ScalarType::Float);
|
||||||
|
STD_TORCH_CHECK(b_scales.scalar_type() ==
|
||||||
|
torch::headeronly::ScalarType::Float);
|
||||||
|
|
||||||
|
if (azp) {
|
||||||
|
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
|
||||||
|
out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
|
||||||
|
} else {
|
||||||
|
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBiasAzp>(
|
||||||
|
out, a, b, a_scales, b_scales, azp_adj, bias);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <template <typename, typename> typename Epilogue,
|
||||||
|
typename... EpilogueArgs>
|
||||||
|
void cutlass_scaled_mm_sm80_epilogue(torch::stable::Tensor& out,
|
||||||
|
torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b,
|
||||||
|
EpilogueArgs&&... epilogue_args) {
|
||||||
|
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||||
|
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||||
|
|
||||||
|
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
|
||||||
|
return cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
|
||||||
|
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||||
|
} else {
|
||||||
|
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
|
||||||
|
return cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t, Epilogue>(
|
||||||
|
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void cutlass_scaled_mm_sm80(torch::stable::Tensor& out,
|
||||||
|
torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b,
|
||||||
|
torch::stable::Tensor const& a_scales,
|
||||||
|
torch::stable::Tensor const& b_scales,
|
||||||
|
std::optional<torch::stable::Tensor> const& bias) {
|
||||||
|
STD_TORCH_CHECK(a_scales.scalar_type() ==
|
||||||
|
torch::headeronly::ScalarType::Float);
|
||||||
|
STD_TORCH_CHECK(b_scales.scalar_type() ==
|
||||||
|
torch::headeronly::ScalarType::Float);
|
||||||
|
if (bias) {
|
||||||
|
STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
|
||||||
|
"currently bias dtype must match output dtype ",
|
||||||
|
out.scalar_type());
|
||||||
|
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBias>(
|
||||||
|
out, a, b, a_scales, b_scales, *bias);
|
||||||
|
} else {
|
||||||
|
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogue>(
|
||||||
|
out, a, b, a_scales, b_scales);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void cutlass_scaled_mm_azp_sm80(
|
||||||
|
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||||
|
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
|
||||||
|
std::optional<torch::stable::Tensor> const& azp,
|
||||||
|
std::optional<torch::stable::Tensor> const& bias) {
|
||||||
|
STD_TORCH_CHECK(a_scales.scalar_type() ==
|
||||||
|
torch::headeronly::ScalarType::Float);
|
||||||
|
STD_TORCH_CHECK(b_scales.scalar_type() ==
|
||||||
|
torch::headeronly::ScalarType::Float);
|
||||||
|
|
||||||
|
if (azp) {
|
||||||
|
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
|
||||||
|
out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
|
||||||
|
} else {
|
||||||
|
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBiasAzp>(
|
||||||
|
out, a, b, a_scales, b_scales, azp_adj, bias);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <template <typename, typename> typename Epilogue,
|
||||||
|
typename... EpilogueArgs>
|
||||||
|
void cutlass_scaled_mm_sm89_epilogue(torch::stable::Tensor& out,
|
||||||
|
torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b,
|
||||||
|
EpilogueArgs&&... epilogue_args) {
|
||||||
|
if (a.scalar_type() == torch::headeronly::ScalarType::Char) {
|
||||||
|
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||||
|
|
||||||
|
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
|
||||||
|
return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::bfloat16_t,
|
||||||
|
Epilogue>(
|
||||||
|
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||||
|
} else {
|
||||||
|
assert(out.scalar_type() == torch::headeronly::ScalarType::Half);
|
||||||
|
return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
|
||||||
|
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
STD_TORCH_CHECK(a.scalar_type() ==
|
||||||
|
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||||
|
STD_TORCH_CHECK(b.scalar_type() ==
|
||||||
|
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||||
|
|
||||||
|
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
|
||||||
|
return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
|
||||||
|
cutlass::bfloat16_t, Epilogue>(
|
||||||
|
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||||
|
} else {
|
||||||
|
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
|
||||||
|
return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
|
||||||
|
cutlass::half_t, Epilogue>(
|
||||||
|
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void cutlass_scaled_mm_sm89(torch::stable::Tensor& out,
|
||||||
|
torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b,
|
||||||
|
torch::stable::Tensor const& a_scales,
|
||||||
|
torch::stable::Tensor const& b_scales,
|
||||||
|
std::optional<torch::stable::Tensor> const& bias) {
|
||||||
|
STD_TORCH_CHECK(a_scales.scalar_type() ==
|
||||||
|
torch::headeronly::ScalarType::Float);
|
||||||
|
STD_TORCH_CHECK(b_scales.scalar_type() ==
|
||||||
|
torch::headeronly::ScalarType::Float);
|
||||||
|
if (bias) {
|
||||||
|
STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
|
||||||
|
"currently bias dtype must match output dtype ",
|
||||||
|
out.scalar_type());
|
||||||
|
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBias>(
|
||||||
|
out, a, b, a_scales, b_scales, *bias);
|
||||||
|
} else {
|
||||||
|
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogue>(
|
||||||
|
out, a, b, a_scales, b_scales);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void cutlass_scaled_mm_azp_sm89(
|
||||||
|
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||||
|
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
|
||||||
|
std::optional<torch::stable::Tensor> const& azp,
|
||||||
|
std::optional<torch::stable::Tensor> const& bias) {
|
||||||
|
STD_TORCH_CHECK(a_scales.scalar_type() ==
|
||||||
|
torch::headeronly::ScalarType::Float);
|
||||||
|
STD_TORCH_CHECK(b_scales.scalar_type() ==
|
||||||
|
torch::headeronly::ScalarType::Float);
|
||||||
|
|
||||||
|
if (azp) {
|
||||||
|
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
|
||||||
|
out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
|
||||||
|
} else {
|
||||||
|
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBiasAzp>(
|
||||||
|
out, a, b, a_scales, b_scales, azp_adj, bias);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,8 +1,9 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
#include <torch/all.h>
|
#include <torch/csrc/stable/tensor.h>
|
||||||
|
#include <torch/csrc/stable/ops.h>
|
||||||
|
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include "libtorch_stable/torch_utils.h"
|
||||||
|
|
||||||
// clang-format will break include orders
|
// clang-format will break include orders
|
||||||
// clang-format off
|
// clang-format off
|
||||||
@@ -95,8 +96,9 @@ struct cutlass_2x_gemm {
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <typename Gemm, typename... EpilogueArgs>
|
template <typename Gemm, typename... EpilogueArgs>
|
||||||
inline void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
|
inline void cutlass_gemm_caller(torch::stable::Tensor& out,
|
||||||
torch::Tensor const& b,
|
torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b,
|
||||||
EpilogueArgs&&... epilogue_params) {
|
EpilogueArgs&&... epilogue_params) {
|
||||||
using ElementAB = typename Gemm::ElementAB;
|
using ElementAB = typename Gemm::ElementAB;
|
||||||
using ElementD = typename Gemm::ElementD;
|
using ElementD = typename Gemm::ElementD;
|
||||||
@@ -149,11 +151,12 @@ inline void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
|
|||||||
// Launch the CUTLASS GEMM kernel.
|
// Launch the CUTLASS GEMM kernel.
|
||||||
typename Gemm::Op gemm_op;
|
typename Gemm::Op gemm_op;
|
||||||
size_t workspace_size = gemm_op.get_workspace_size(args);
|
size_t workspace_size = gemm_op.get_workspace_size(args);
|
||||||
auto const workspace_options =
|
auto device = a.device();
|
||||||
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
|
auto workspace =
|
||||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
|
||||||
|
std::nullopt, device);
|
||||||
|
|
||||||
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
|
auto stream = get_current_cuda_stream(device.index());
|
||||||
|
|
||||||
CUTLASS_CHECK(gemm_op.can_implement(args));
|
CUTLASS_CHECK(gemm_op.can_implement(args));
|
||||||
cutlass::Status status = gemm_op(args, workspace.data_ptr(), stream);
|
cutlass::Status status = gemm_op(args, workspace.data_ptr(), stream);
|
||||||
@@ -161,9 +164,9 @@ inline void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename Gemm, typename FallbackGemm, typename... EpilogueArgs>
|
template <typename Gemm, typename FallbackGemm, typename... EpilogueArgs>
|
||||||
inline void fallback_cutlass_gemm_caller(torch::Tensor& out,
|
inline void fallback_cutlass_gemm_caller(torch::stable::Tensor& out,
|
||||||
torch::Tensor const& a,
|
torch::stable::Tensor const& a,
|
||||||
torch::Tensor const& b,
|
torch::stable::Tensor const& b,
|
||||||
EpilogueArgs&&... args) {
|
EpilogueArgs&&... args) {
|
||||||
// In some cases, the GPU isn't able to accommodate the
|
// In some cases, the GPU isn't able to accommodate the
|
||||||
// shared memory requirements of the Gemm. In such cases, use
|
// shared memory requirements of the Gemm. In such cases, use
|
||||||
@@ -180,8 +183,8 @@ inline void fallback_cutlass_gemm_caller(torch::Tensor& out,
|
|||||||
return cutlass_gemm_caller<Gemm>(out, a, b,
|
return cutlass_gemm_caller<Gemm>(out, a, b,
|
||||||
std::forward<EpilogueArgs>(args)...);
|
std::forward<EpilogueArgs>(args)...);
|
||||||
} else {
|
} else {
|
||||||
TORCH_CHECK(fallback_gemm_shared_mem_size <=
|
STD_TORCH_CHECK(fallback_gemm_shared_mem_size <=
|
||||||
max_shared_mem_per_block_opt_in);
|
max_shared_mem_per_block_opt_in);
|
||||||
return cutlass_gemm_caller<FallbackGemm>(
|
return cutlass_gemm_caller<FallbackGemm>(
|
||||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||||
}
|
}
|
||||||
@@ -1,5 +1,7 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <torch/headeronly/util/shim_utils.h>
|
||||||
|
|
||||||
#include "scaled_mm_c2x.cuh"
|
#include "scaled_mm_c2x.cuh"
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -70,13 +72,13 @@ struct sm75_config_M32 {
|
|||||||
template <typename InType, typename OutType,
|
template <typename InType, typename OutType,
|
||||||
template <typename, typename> typename Epilogue,
|
template <typename, typename> typename Epilogue,
|
||||||
typename... EpilogueArgs>
|
typename... EpilogueArgs>
|
||||||
inline void cutlass_gemm_sm75_dispatch(torch::Tensor& out,
|
inline void cutlass_gemm_sm75_dispatch(torch::stable::Tensor& out,
|
||||||
torch::Tensor const& a,
|
torch::stable::Tensor const& a,
|
||||||
torch::Tensor const& b,
|
torch::stable::Tensor const& b,
|
||||||
EpilogueArgs&&... args) {
|
EpilogueArgs&&... args) {
|
||||||
static_assert(std::is_same<InType, int8_t>());
|
static_assert(std::is_same<InType, int8_t>());
|
||||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||||
|
|
||||||
using Cutlass2xGemmDefault =
|
using Cutlass2xGemmDefault =
|
||||||
typename sm75_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
|
typename sm75_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
|
||||||
@@ -1,5 +1,7 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <torch/headeronly/util/shim_utils.h>
|
||||||
|
|
||||||
#include "scaled_mm_c2x.cuh"
|
#include "scaled_mm_c2x.cuh"
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -72,13 +74,13 @@ struct sm80_config_M16 {
|
|||||||
template <typename InType, typename OutType,
|
template <typename InType, typename OutType,
|
||||||
template <typename, typename> typename Epilogue,
|
template <typename, typename> typename Epilogue,
|
||||||
typename... EpilogueArgs>
|
typename... EpilogueArgs>
|
||||||
inline void cutlass_gemm_sm80_dispatch(torch::Tensor& out,
|
inline void cutlass_gemm_sm80_dispatch(torch::stable::Tensor& out,
|
||||||
torch::Tensor const& a,
|
torch::stable::Tensor const& a,
|
||||||
torch::Tensor const& b,
|
torch::stable::Tensor const& b,
|
||||||
EpilogueArgs&&... args) {
|
EpilogueArgs&&... args) {
|
||||||
static_assert(std::is_same<InType, int8_t>());
|
static_assert(std::is_same<InType, int8_t>());
|
||||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||||
|
|
||||||
using Cutlass2xGemmDefault =
|
using Cutlass2xGemmDefault =
|
||||||
typename sm80_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
|
typename sm80_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
|
||||||
@@ -1,5 +1,7 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <torch/headeronly/util/shim_utils.h>
|
||||||
|
|
||||||
#include "scaled_mm_c2x.cuh"
|
#include "scaled_mm_c2x.cuh"
|
||||||
#include "cutlass/float8.h"
|
#include "cutlass/float8.h"
|
||||||
|
|
||||||
@@ -34,10 +36,12 @@ struct sm89_fp8_config_default {
|
|||||||
template <typename InType, typename OutType,
|
template <typename InType, typename OutType,
|
||||||
template <typename, typename> typename Epilogue,
|
template <typename, typename> typename Epilogue,
|
||||||
typename... EpilogueArgs>
|
typename... EpilogueArgs>
|
||||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
static void dispatch(torch::stable::Tensor& out,
|
||||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
|
||||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
STD_TORCH_CHECK(a.scalar_type() ==
|
||||||
|
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||||
|
|
||||||
using FallbackGemm =
|
using FallbackGemm =
|
||||||
typename sm89_fp8_fallback_gemm<InType, OutType,
|
typename sm89_fp8_fallback_gemm<InType, OutType,
|
||||||
@@ -84,10 +88,12 @@ struct sm89_fp8_config_M256 {
|
|||||||
template <typename InType, typename OutType,
|
template <typename InType, typename OutType,
|
||||||
template <typename, typename> typename Epilogue,
|
template <typename, typename> typename Epilogue,
|
||||||
typename... EpilogueArgs>
|
typename... EpilogueArgs>
|
||||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
static void dispatch(torch::stable::Tensor& out,
|
||||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
|
||||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
STD_TORCH_CHECK(a.scalar_type() ==
|
||||||
|
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||||
|
|
||||||
using FallbackGemm =
|
using FallbackGemm =
|
||||||
typename sm89_fp8_fallback_gemm<InType, OutType,
|
typename sm89_fp8_fallback_gemm<InType, OutType,
|
||||||
@@ -125,10 +131,12 @@ struct sm89_fp8_config_M128 {
|
|||||||
template <typename InType, typename OutType,
|
template <typename InType, typename OutType,
|
||||||
template <typename, typename> typename Epilogue,
|
template <typename, typename> typename Epilogue,
|
||||||
typename... EpilogueArgs>
|
typename... EpilogueArgs>
|
||||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
static void dispatch(torch::stable::Tensor& out,
|
||||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
|
||||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
STD_TORCH_CHECK(a.scalar_type() ==
|
||||||
|
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||||
|
|
||||||
using FallbackGemm =
|
using FallbackGemm =
|
||||||
typename sm89_fp8_fallback_gemm<InType, OutType,
|
typename sm89_fp8_fallback_gemm<InType, OutType,
|
||||||
@@ -173,10 +181,12 @@ struct sm89_fp8_config_M64 {
|
|||||||
template <typename InType, typename OutType,
|
template <typename InType, typename OutType,
|
||||||
template <typename, typename> typename Epilogue,
|
template <typename, typename> typename Epilogue,
|
||||||
typename... EpilogueArgs>
|
typename... EpilogueArgs>
|
||||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
static void dispatch(torch::stable::Tensor& out,
|
||||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
|
||||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
STD_TORCH_CHECK(a.scalar_type() ==
|
||||||
|
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||||
|
|
||||||
using FallbackGemm =
|
using FallbackGemm =
|
||||||
typename sm89_fp8_fallback_gemm<InType, OutType,
|
typename sm89_fp8_fallback_gemm<InType, OutType,
|
||||||
@@ -227,10 +237,12 @@ struct sm89_fp8_config_M32 {
|
|||||||
template <typename InType, typename OutType,
|
template <typename InType, typename OutType,
|
||||||
template <typename, typename> typename Epilogue,
|
template <typename, typename> typename Epilogue,
|
||||||
typename... EpilogueArgs>
|
typename... EpilogueArgs>
|
||||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
static void dispatch(torch::stable::Tensor& out,
|
||||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
|
||||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
STD_TORCH_CHECK(a.scalar_type() ==
|
||||||
|
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||||
|
|
||||||
using FallbackGemm =
|
using FallbackGemm =
|
||||||
typename sm89_fp8_fallback_gemm<InType, OutType,
|
typename sm89_fp8_fallback_gemm<InType, OutType,
|
||||||
@@ -280,10 +292,12 @@ struct sm89_fp8_config_M16 {
|
|||||||
template <typename InType, typename OutType,
|
template <typename InType, typename OutType,
|
||||||
template <typename, typename> typename Epilogue,
|
template <typename, typename> typename Epilogue,
|
||||||
typename... EpilogueArgs>
|
typename... EpilogueArgs>
|
||||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
static void dispatch(torch::stable::Tensor& out,
|
||||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
|
||||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
STD_TORCH_CHECK(a.scalar_type() ==
|
||||||
|
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||||
|
|
||||||
using FallbackGemm =
|
using FallbackGemm =
|
||||||
typename sm89_fp8_fallback_gemm<InType, OutType,
|
typename sm89_fp8_fallback_gemm<InType, OutType,
|
||||||
@@ -326,13 +340,15 @@ struct sm89_fp8_config_M16 {
|
|||||||
template <typename InType, typename OutType,
|
template <typename InType, typename OutType,
|
||||||
template <typename, typename> typename Epilogue,
|
template <typename, typename> typename Epilogue,
|
||||||
typename... EpilogueArgs>
|
typename... EpilogueArgs>
|
||||||
inline void cutlass_gemm_sm89_fp8_dispatch(torch::Tensor& out,
|
inline void cutlass_gemm_sm89_fp8_dispatch(torch::stable::Tensor& out,
|
||||||
torch::Tensor const& a,
|
torch::stable::Tensor const& a,
|
||||||
torch::Tensor const& b,
|
torch::stable::Tensor const& b,
|
||||||
EpilogueArgs&&... args) {
|
EpilogueArgs&&... args) {
|
||||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
STD_TORCH_CHECK(a.scalar_type() ==
|
||||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||||
|
STD_TORCH_CHECK(b.scalar_type() ==
|
||||||
|
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||||
|
|
||||||
uint32_t const m = a.size(0);
|
uint32_t const m = a.size(0);
|
||||||
uint32_t const mp2 =
|
uint32_t const mp2 =
|
||||||
@@ -1,5 +1,7 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <torch/headeronly/util/shim_utils.h>
|
||||||
|
|
||||||
#include "scaled_mm_c2x.cuh"
|
#include "scaled_mm_c2x.cuh"
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -32,10 +34,11 @@ struct sm89_int8_config_default {
|
|||||||
template <typename InType, typename OutType,
|
template <typename InType, typename OutType,
|
||||||
template <typename, typename> typename Epilogue,
|
template <typename, typename> typename Epilogue,
|
||||||
typename... EpilogueArgs>
|
typename... EpilogueArgs>
|
||||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
static void dispatch(torch::stable::Tensor& out,
|
||||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
|
||||||
static_assert(std::is_same<InType, int8_t>());
|
static_assert(std::is_same<InType, int8_t>());
|
||||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||||
|
|
||||||
using FallbackGemm =
|
using FallbackGemm =
|
||||||
typename sm89_int8_fallback_gemm<InType, OutType,
|
typename sm89_int8_fallback_gemm<InType, OutType,
|
||||||
@@ -88,10 +91,11 @@ struct sm89_int8_config_M256 {
|
|||||||
template <typename InType, typename OutType,
|
template <typename InType, typename OutType,
|
||||||
template <typename, typename> typename Epilogue,
|
template <typename, typename> typename Epilogue,
|
||||||
typename... EpilogueArgs>
|
typename... EpilogueArgs>
|
||||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
static void dispatch(torch::stable::Tensor& out,
|
||||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
|
||||||
static_assert(std::is_same<InType, int8_t>());
|
static_assert(std::is_same<InType, int8_t>());
|
||||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||||
|
|
||||||
using FallbackGemm =
|
using FallbackGemm =
|
||||||
typename sm89_int8_fallback_gemm<InType, OutType,
|
typename sm89_int8_fallback_gemm<InType, OutType,
|
||||||
@@ -143,10 +147,11 @@ struct sm89_int8_config_M128 {
|
|||||||
template <typename InType, typename OutType,
|
template <typename InType, typename OutType,
|
||||||
template <typename, typename> typename Epilogue,
|
template <typename, typename> typename Epilogue,
|
||||||
typename... EpilogueArgs>
|
typename... EpilogueArgs>
|
||||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
static void dispatch(torch::stable::Tensor& out,
|
||||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
|
||||||
static_assert(std::is_same<InType, int8_t>());
|
static_assert(std::is_same<InType, int8_t>());
|
||||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||||
|
|
||||||
using FallbackGemm =
|
using FallbackGemm =
|
||||||
typename sm89_int8_fallback_gemm<InType, OutType,
|
typename sm89_int8_fallback_gemm<InType, OutType,
|
||||||
@@ -193,10 +198,11 @@ struct sm89_int8_config_M64 {
|
|||||||
template <typename InType, typename OutType,
|
template <typename InType, typename OutType,
|
||||||
template <typename, typename> typename Epilogue,
|
template <typename, typename> typename Epilogue,
|
||||||
typename... EpilogueArgs>
|
typename... EpilogueArgs>
|
||||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
static void dispatch(torch::stable::Tensor& out,
|
||||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
|
||||||
static_assert(std::is_same<InType, int8_t>());
|
static_assert(std::is_same<InType, int8_t>());
|
||||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||||
|
|
||||||
using FallbackGemm =
|
using FallbackGemm =
|
||||||
typename sm89_int8_fallback_gemm<InType, OutType,
|
typename sm89_int8_fallback_gemm<InType, OutType,
|
||||||
@@ -234,10 +240,11 @@ struct sm89_int8_config_M32 {
|
|||||||
template <typename InType, typename OutType,
|
template <typename InType, typename OutType,
|
||||||
template <typename, typename> typename Epilogue,
|
template <typename, typename> typename Epilogue,
|
||||||
typename... EpilogueArgs>
|
typename... EpilogueArgs>
|
||||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
static void dispatch(torch::stable::Tensor& out,
|
||||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
|
||||||
static_assert(std::is_same<InType, int8_t>());
|
static_assert(std::is_same<InType, int8_t>());
|
||||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||||
|
|
||||||
using FallbackGemm =
|
using FallbackGemm =
|
||||||
typename sm89_int8_fallback_gemm<InType, OutType,
|
typename sm89_int8_fallback_gemm<InType, OutType,
|
||||||
@@ -276,10 +283,11 @@ struct sm89_int8_config_M16 {
|
|||||||
template <typename InType, typename OutType,
|
template <typename InType, typename OutType,
|
||||||
template <typename, typename> typename Epilogue,
|
template <typename, typename> typename Epilogue,
|
||||||
typename... EpilogueArgs>
|
typename... EpilogueArgs>
|
||||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
static void dispatch(torch::stable::Tensor& out,
|
||||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
|
||||||
static_assert(std::is_same<InType, int8_t>());
|
static_assert(std::is_same<InType, int8_t>());
|
||||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||||
|
|
||||||
using FallbackGemm =
|
using FallbackGemm =
|
||||||
typename sm89_int8_fallback_gemm<InType, OutType,
|
typename sm89_int8_fallback_gemm<InType, OutType,
|
||||||
@@ -311,13 +319,13 @@ struct sm89_int8_config_M16 {
|
|||||||
template <typename InType, typename OutType,
|
template <typename InType, typename OutType,
|
||||||
template <typename, typename> typename Epilogue,
|
template <typename, typename> typename Epilogue,
|
||||||
typename... EpilogueArgs>
|
typename... EpilogueArgs>
|
||||||
inline void cutlass_gemm_sm89_int8_dispatch(torch::Tensor& out,
|
inline void cutlass_gemm_sm89_int8_dispatch(torch::stable::Tensor& out,
|
||||||
torch::Tensor const& a,
|
torch::stable::Tensor const& a,
|
||||||
torch::Tensor const& b,
|
torch::stable::Tensor const& b,
|
||||||
EpilogueArgs&&... args) {
|
EpilogueArgs&&... args) {
|
||||||
static_assert(std::is_same<InType, int8_t>());
|
static_assert(std::is_same<InType, int8_t>());
|
||||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||||
|
|
||||||
uint32_t const m = a.size(0);
|
uint32_t const m = a.size(0);
|
||||||
uint32_t const mp2 =
|
uint32_t const mp2 =
|
||||||
@@ -8,11 +8,12 @@
|
|||||||
|
|
||||||
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
|
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
|
||||||
|
|
||||||
void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
|
void cutlass_scaled_mm_sm100(torch::stable::Tensor& c,
|
||||||
torch::Tensor const& b,
|
torch::stable::Tensor const& a,
|
||||||
torch::Tensor const& a_scales,
|
torch::stable::Tensor const& b,
|
||||||
torch::Tensor const& b_scales,
|
torch::stable::Tensor const& a_scales,
|
||||||
std::optional<torch::Tensor> const& bias) {
|
torch::stable::Tensor const& b_scales,
|
||||||
|
std::optional<torch::stable::Tensor> const& bias) {
|
||||||
dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
|
dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
|
||||||
vllm::cutlass_scaled_mm_sm100_fp8,
|
vllm::cutlass_scaled_mm_sm100_fp8,
|
||||||
nullptr, // int8 not supported on SM100
|
nullptr, // int8 not supported on SM100
|
||||||
@@ -8,11 +8,12 @@
|
|||||||
|
|
||||||
#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
|
#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
|
||||||
|
|
||||||
void cutlass_scaled_mm_sm120(torch::Tensor& c, torch::Tensor const& a,
|
void cutlass_scaled_mm_sm120(torch::stable::Tensor& c,
|
||||||
torch::Tensor const& b,
|
torch::stable::Tensor const& a,
|
||||||
torch::Tensor const& a_scales,
|
torch::stable::Tensor const& b,
|
||||||
torch::Tensor const& b_scales,
|
torch::stable::Tensor const& a_scales,
|
||||||
std::optional<torch::Tensor> const& bias) {
|
torch::stable::Tensor const& b_scales,
|
||||||
|
std::optional<torch::stable::Tensor> const& bias) {
|
||||||
dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
|
dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
|
||||||
vllm::cutlass_scaled_mm_sm120_fp8,
|
vllm::cutlass_scaled_mm_sm120_fp8,
|
||||||
nullptr, // int8 not supported on SM120
|
nullptr, // int8 not supported on SM120
|
||||||
@@ -0,0 +1,38 @@
|
|||||||
|
#include "c3x/scaled_mm_helper.hpp"
|
||||||
|
#include "c3x/scaled_mm_kernels.hpp"
|
||||||
|
|
||||||
|
/*
|
||||||
|
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
|
||||||
|
NVIDIA GPUs with sm90a (Hopper).
|
||||||
|
*/
|
||||||
|
|
||||||
|
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
|
||||||
|
|
||||||
|
void cutlass_scaled_mm_sm90(torch::stable::Tensor& c,
|
||||||
|
torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b,
|
||||||
|
torch::stable::Tensor const& a_scales,
|
||||||
|
torch::stable::Tensor const& b_scales,
|
||||||
|
std::optional<torch::stable::Tensor> const& bias) {
|
||||||
|
dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
|
||||||
|
vllm::cutlass_scaled_mm_sm90_fp8,
|
||||||
|
vllm::cutlass_scaled_mm_sm90_int8,
|
||||||
|
vllm::cutlass_scaled_mm_blockwise_sm90_fp8);
|
||||||
|
}
|
||||||
|
|
||||||
|
void cutlass_scaled_mm_azp_sm90(
|
||||||
|
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||||
|
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
|
||||||
|
std::optional<torch::stable::Tensor> const& azp,
|
||||||
|
std::optional<torch::stable::Tensor> const& bias) {
|
||||||
|
STD_TORCH_CHECK(a_scales.scalar_type() ==
|
||||||
|
torch::headeronly::ScalarType::Float);
|
||||||
|
STD_TORCH_CHECK(b_scales.scalar_type() ==
|
||||||
|
torch::headeronly::ScalarType::Float);
|
||||||
|
|
||||||
|
vllm::cutlass_scaled_mm_azp_sm90_int8(out, a, b, a_scales, b_scales, azp_adj,
|
||||||
|
azp, bias);
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
@@ -0,0 +1,451 @@
|
|||||||
|
#include <cudaTypedefs.h>
|
||||||
|
|
||||||
|
#include <torch/csrc/stable/tensor.h>
|
||||||
|
|
||||||
|
#include "libtorch_stable/torch_utils.h"
|
||||||
|
|
||||||
|
#include "cutlass_extensions/common.hpp"
|
||||||
|
|
||||||
|
void cutlass_scaled_mm_sm75(torch::stable::Tensor& c,
|
||||||
|
torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b,
|
||||||
|
torch::stable::Tensor const& a_scales,
|
||||||
|
torch::stable::Tensor const& b_scales,
|
||||||
|
std::optional<torch::stable::Tensor> const& bias);
|
||||||
|
|
||||||
|
void cutlass_scaled_mm_sm80(torch::stable::Tensor& c,
|
||||||
|
torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b,
|
||||||
|
torch::stable::Tensor const& a_scales,
|
||||||
|
torch::stable::Tensor const& b_scales,
|
||||||
|
std::optional<torch::stable::Tensor> const& bias);
|
||||||
|
|
||||||
|
void cutlass_scaled_mm_sm89(torch::stable::Tensor& c,
|
||||||
|
torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b,
|
||||||
|
torch::stable::Tensor const& a_scales,
|
||||||
|
torch::stable::Tensor const& b_scales,
|
||||||
|
std::optional<torch::stable::Tensor> const& bias);
|
||||||
|
|
||||||
|
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
|
||||||
|
void cutlass_scaled_mm_sm90(torch::stable::Tensor& c,
|
||||||
|
torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b,
|
||||||
|
torch::stable::Tensor const& a_scales,
|
||||||
|
torch::stable::Tensor const& b_scales,
|
||||||
|
std::optional<torch::stable::Tensor> const& bias);
|
||||||
|
#endif
|
||||||
|
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
|
||||||
|
void cutlass_moe_mm_sm90(torch::stable::Tensor& out_tensors,
|
||||||
|
torch::stable::Tensor const& a_tensors,
|
||||||
|
torch::stable::Tensor const& b_tensors,
|
||||||
|
torch::stable::Tensor const& a_scales,
|
||||||
|
torch::stable::Tensor const& b_scales,
|
||||||
|
torch::stable::Tensor const& expert_offsets,
|
||||||
|
torch::stable::Tensor const& problem_sizes,
|
||||||
|
torch::stable::Tensor const& a_strides,
|
||||||
|
torch::stable::Tensor const& b_strides,
|
||||||
|
torch::stable::Tensor const& c_strides,
|
||||||
|
bool per_act_token, bool per_out_ch);
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100
|
||||||
|
void cutlass_moe_mm_sm100(torch::stable::Tensor& out_tensors,
|
||||||
|
torch::stable::Tensor const& a_tensors,
|
||||||
|
torch::stable::Tensor const& b_tensors,
|
||||||
|
torch::stable::Tensor const& a_scales,
|
||||||
|
torch::stable::Tensor const& b_scales,
|
||||||
|
torch::stable::Tensor const& expert_offsets,
|
||||||
|
torch::stable::Tensor const& problem_sizes,
|
||||||
|
torch::stable::Tensor const& a_strides,
|
||||||
|
torch::stable::Tensor const& b_strides,
|
||||||
|
torch::stable::Tensor const& c_strides,
|
||||||
|
bool per_act_token, bool per_out_ch);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
|
||||||
|
void cutlass_scaled_mm_sm120(torch::stable::Tensor& c,
|
||||||
|
torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b,
|
||||||
|
torch::stable::Tensor const& a_scales,
|
||||||
|
torch::stable::Tensor const& b_scales,
|
||||||
|
std::optional<torch::stable::Tensor> const& bias);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
|
||||||
|
void cutlass_scaled_mm_sm100(torch::stable::Tensor& c,
|
||||||
|
torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b,
|
||||||
|
torch::stable::Tensor const& a_scales,
|
||||||
|
torch::stable::Tensor const& b_scales,
|
||||||
|
std::optional<torch::stable::Tensor> const& bias);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if (defined(ENABLE_CUTLASS_MOE_SM90) && ENABLE_CUTLASS_MOE_SM90) || \
|
||||||
|
(defined(ENABLE_CUTLASS_MOE_SM100) && ENABLE_CUTLASS_MOE_SM100) || \
|
||||||
|
(defined(ENABLE_CUTLASS_MOE_SM120) && ENABLE_CUTLASS_MOE_SM120)
|
||||||
|
void get_cutlass_moe_mm_data_caller(
|
||||||
|
const torch::stable::Tensor& topk_ids,
|
||||||
|
torch::stable::Tensor& expert_offsets,
|
||||||
|
torch::stable::Tensor& problem_sizes1,
|
||||||
|
torch::stable::Tensor& problem_sizes2,
|
||||||
|
torch::stable::Tensor& input_permutation,
|
||||||
|
torch::stable::Tensor& output_permutation, const int64_t num_experts,
|
||||||
|
const int64_t n, const int64_t k,
|
||||||
|
const std::optional<torch::stable::Tensor>& blockscale_offsets,
|
||||||
|
const bool is_gated);
|
||||||
|
|
||||||
|
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
|
||||||
|
const torch::stable::Tensor& expert_first_token_offset,
|
||||||
|
torch::stable::Tensor& problem_sizes1,
|
||||||
|
torch::stable::Tensor& problem_sizes2, const int64_t n, const int64_t k,
|
||||||
|
const bool swap_ab);
|
||||||
|
|
||||||
|
void get_cutlass_batched_moe_mm_data_caller(
|
||||||
|
torch::stable::Tensor& expert_offsets,
|
||||||
|
torch::stable::Tensor& problem_sizes1,
|
||||||
|
torch::stable::Tensor& problem_sizes2,
|
||||||
|
const torch::stable::Tensor& expert_num_tokens,
|
||||||
|
const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
|
||||||
|
const int64_t k);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
void cutlass_scaled_mm_azp_sm75(
|
||||||
|
torch::stable::Tensor& c, torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||||
|
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
|
||||||
|
std::optional<torch::stable::Tensor> const& azp,
|
||||||
|
std::optional<torch::stable::Tensor> const& bias);
|
||||||
|
|
||||||
|
void cutlass_scaled_mm_azp_sm80(
|
||||||
|
torch::stable::Tensor& c, torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||||
|
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
|
||||||
|
std::optional<torch::stable::Tensor> const& azp,
|
||||||
|
std::optional<torch::stable::Tensor> const& bias);
|
||||||
|
|
||||||
|
void cutlass_scaled_mm_azp_sm89(
|
||||||
|
torch::stable::Tensor& c, torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||||
|
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
|
||||||
|
std::optional<torch::stable::Tensor> const& azp,
|
||||||
|
std::optional<torch::stable::Tensor> const& bias);
|
||||||
|
|
||||||
|
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
|
||||||
|
void cutlass_scaled_mm_azp_sm90(
|
||||||
|
torch::stable::Tensor& c, torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||||
|
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
|
||||||
|
std::optional<torch::stable::Tensor> const& azp,
|
||||||
|
std::optional<torch::stable::Tensor> const& bias);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
|
||||||
|
// CUTLASS FP8 kernels need at least
|
||||||
|
// CUDA 12.0 on SM90 systems (Hopper)
|
||||||
|
// CUDA 12.4 on SM89 systems (Lovelace)
|
||||||
|
|
||||||
|
#if defined CUDA_VERSION
|
||||||
|
if (cuda_device_capability >= 90) {
|
||||||
|
return CUDA_VERSION >= 12000;
|
||||||
|
} else if (cuda_device_capability >= 89) {
|
||||||
|
return CUDA_VERSION >= 12040;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
|
||||||
|
// CUTLASS block-quantized FP8 kernels need at least CUDA 12.0
|
||||||
|
// and at least SM90 (Hopper)
|
||||||
|
|
||||||
|
#if defined CUDA_VERSION
|
||||||
|
if (cuda_device_capability >= 100) {
|
||||||
|
return CUDA_VERSION >= 12080;
|
||||||
|
} else if (cuda_device_capability >= 90) {
|
||||||
|
return CUDA_VERSION >= 12000;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool cutlass_group_gemm_supported(int64_t cuda_device_capability) {
|
||||||
|
// CUTLASS grouped FP8 kernels need at least CUDA 12.3 and SM90 (Hopper)
|
||||||
|
// or CUDA 12.8 and SM100 (Blackwell)
|
||||||
|
|
||||||
|
#if defined CUDA_VERSION
|
||||||
|
if (cuda_device_capability >= 100) {
|
||||||
|
return CUDA_VERSION >= 12080;
|
||||||
|
}
|
||||||
|
if (cuda_device_capability >= 90) {
|
||||||
|
return CUDA_VERSION >= 12030;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
void cutlass_scaled_mm(torch::stable::Tensor& c, torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b,
|
||||||
|
torch::stable::Tensor const& a_scales,
|
||||||
|
torch::stable::Tensor const& b_scales,
|
||||||
|
std::optional<torch::stable::Tensor> const& bias) {
|
||||||
|
// Checks for conformality
|
||||||
|
STD_TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
|
||||||
|
STD_TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
|
||||||
|
b.size(1) == c.size(1));
|
||||||
|
|
||||||
|
// Check for strides and alignment
|
||||||
|
STD_TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major
|
||||||
|
STD_TORCH_CHECK(b.stride(0) == 1); // Column-major
|
||||||
|
STD_TORCH_CHECK(c.stride(0) % 16 == 0 &&
|
||||||
|
b.stride(1) % 16 == 0); // 16 Byte Alignment
|
||||||
|
|
||||||
|
if (bias) {
|
||||||
|
STD_TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
|
||||||
|
bias->dim() == 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||||
|
a.get_device_index());
|
||||||
|
int32_t version_num = get_sm_version_num();
|
||||||
|
|
||||||
|
#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
|
||||||
|
if (version_num >= 120) {
|
||||||
|
cutlass_scaled_mm_sm120(c, a, b, a_scales, b_scales, bias);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
|
||||||
|
if (version_num >= 100 && version_num < 120) {
|
||||||
|
cutlass_scaled_mm_sm100(c, a, b, a_scales, b_scales, bias);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// Guard against compilation issues for sm90 kernels
|
||||||
|
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
|
||||||
|
if (version_num >= 90 && version_num < 100) {
|
||||||
|
// Hopper
|
||||||
|
cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
|
||||||
|
if (version_num == 89) {
|
||||||
|
// Ada Lovelace
|
||||||
|
cutlass_scaled_mm_sm89(c, a, b, a_scales, b_scales, bias);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (version_num >= 80) {
|
||||||
|
// Ampere
|
||||||
|
cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (version_num >= 75) {
|
||||||
|
// Turing
|
||||||
|
cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
STD_TORCH_CHECK_NOT_IMPLEMENTED(
|
||||||
|
false,
|
||||||
|
"No compiled cutlass_scaled_mm for a compute capability less than "
|
||||||
|
"CUDA device capability: ",
|
||||||
|
version_num);
|
||||||
|
}
|
||||||
|
|
||||||
|
void cutlass_moe_mm(torch::stable::Tensor& out_tensors,
|
||||||
|
torch::stable::Tensor const& a_tensors,
|
||||||
|
torch::stable::Tensor const& b_tensors,
|
||||||
|
torch::stable::Tensor const& a_scales,
|
||||||
|
torch::stable::Tensor const& b_scales,
|
||||||
|
torch::stable::Tensor const& expert_offsets,
|
||||||
|
torch::stable::Tensor const& problem_sizes,
|
||||||
|
torch::stable::Tensor const& a_strides,
|
||||||
|
torch::stable::Tensor const& b_strides,
|
||||||
|
torch::stable::Tensor const& c_strides, bool per_act_token,
|
||||||
|
bool per_out_ch) {
|
||||||
|
int32_t version_num = get_sm_version_num();
|
||||||
|
#if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100
|
||||||
|
if (version_num >= 100 && version_num < 110) {
|
||||||
|
cutlass_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
|
||||||
|
expert_offsets, problem_sizes, a_strides, b_strides,
|
||||||
|
c_strides, per_act_token, per_out_ch);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
|
||||||
|
if (version_num >= 90 && version_num < 100) {
|
||||||
|
cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
|
||||||
|
expert_offsets, problem_sizes, a_strides, b_strides,
|
||||||
|
c_strides, per_act_token, per_out_ch);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
STD_TORCH_CHECK_NOT_IMPLEMENTED(
|
||||||
|
false,
|
||||||
|
"No compiled cutlass_scaled_mm for CUDA device capability: ", version_num,
|
||||||
|
". Required capability: 90 or 100");
|
||||||
|
}
|
||||||
|
|
||||||
|
void get_cutlass_moe_mm_data(
|
||||||
|
const torch::stable::Tensor& topk_ids,
|
||||||
|
torch::stable::Tensor& expert_offsets,
|
||||||
|
torch::stable::Tensor& problem_sizes1,
|
||||||
|
torch::stable::Tensor& problem_sizes2,
|
||||||
|
torch::stable::Tensor& input_permutation,
|
||||||
|
torch::stable::Tensor& output_permutation, const int64_t num_experts,
|
||||||
|
const int64_t n, const int64_t k,
|
||||||
|
const std::optional<torch::stable::Tensor>& blockscale_offsets,
|
||||||
|
const bool is_gated) {
|
||||||
|
// This function currently gets compiled only if we have a valid cutlass moe
|
||||||
|
// mm to run it for.
|
||||||
|
int32_t version_num = get_sm_version_num();
|
||||||
|
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
|
||||||
|
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
|
||||||
|
(defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
|
||||||
|
get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1,
|
||||||
|
problem_sizes2, input_permutation,
|
||||||
|
output_permutation, num_experts, n, k,
|
||||||
|
blockscale_offsets, is_gated);
|
||||||
|
return;
|
||||||
|
#endif
|
||||||
|
STD_TORCH_CHECK_NOT_IMPLEMENTED(
|
||||||
|
false,
|
||||||
|
"No compiled get_cutlass_moe_mm_data: no cutlass_scaled_mm kernel for "
|
||||||
|
"CUDA device capability: ",
|
||||||
|
version_num, ". Required capability: 90, 100, or 120");
|
||||||
|
}
|
||||||
|
|
||||||
|
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
|
||||||
|
const torch::stable::Tensor& expert_first_token_offset,
|
||||||
|
torch::stable::Tensor& problem_sizes1,
|
||||||
|
torch::stable::Tensor& problem_sizes2, const int64_t n, const int64_t k,
|
||||||
|
const bool swap_ab) {
|
||||||
|
int32_t version_num = get_sm_version_num();
|
||||||
|
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
|
||||||
|
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
|
||||||
|
(defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
|
||||||
|
get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
|
||||||
|
expert_first_token_offset, problem_sizes1, problem_sizes2, n, k, swap_ab);
|
||||||
|
return;
|
||||||
|
#endif
|
||||||
|
STD_TORCH_CHECK_NOT_IMPLEMENTED(
|
||||||
|
false,
|
||||||
|
"No compiled get_cutlass_moe_mm_problem_sizes_from_expert_offsets: "
|
||||||
|
"no cutlass_scaled_mm kernel for CUDA device capability: ",
|
||||||
|
version_num, ". Required capability: 90, 100, or 120");
|
||||||
|
}
|
||||||
|
|
||||||
|
void get_cutlass_batched_moe_mm_data(
|
||||||
|
torch::stable::Tensor& expert_offsets,
|
||||||
|
torch::stable::Tensor& problem_sizes1,
|
||||||
|
torch::stable::Tensor& problem_sizes2,
|
||||||
|
const torch::stable::Tensor& expert_num_tokens,
|
||||||
|
const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
|
||||||
|
const int64_t k) {
|
||||||
|
// This function currently gets compiled only if we have a valid cutlass moe
|
||||||
|
// mm to run it for.
|
||||||
|
int32_t version_num = get_sm_version_num();
|
||||||
|
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
|
||||||
|
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
|
||||||
|
(defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
|
||||||
|
get_cutlass_batched_moe_mm_data_caller(expert_offsets, problem_sizes1,
|
||||||
|
problem_sizes2, expert_num_tokens,
|
||||||
|
num_local_experts, padded_m, n, k);
|
||||||
|
return;
|
||||||
|
#endif
|
||||||
|
STD_TORCH_CHECK_NOT_IMPLEMENTED(
|
||||||
|
false,
|
||||||
|
"No compiled get_cutlass_batched_moe_mm_data: no "
|
||||||
|
"cutlass_scaled_mm kernel "
|
||||||
|
"for CUDA device capability: ",
|
||||||
|
version_num, ". Required capability: 90, 100, or 120");
|
||||||
|
}
|
||||||
|
|
||||||
|
void cutlass_scaled_mm_azp(torch::stable::Tensor& c,
|
||||||
|
torch::stable::Tensor const& a,
|
||||||
|
torch::stable::Tensor const& b,
|
||||||
|
torch::stable::Tensor const& a_scales,
|
||||||
|
torch::stable::Tensor const& b_scales,
|
||||||
|
torch::stable::Tensor const& azp_adj,
|
||||||
|
std::optional<torch::stable::Tensor> const& azp,
|
||||||
|
std::optional<torch::stable::Tensor> const& bias) {
|
||||||
|
// Checks for conformality
|
||||||
|
STD_TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
|
||||||
|
STD_TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
|
||||||
|
b.size(1) == c.size(1));
|
||||||
|
STD_TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
|
||||||
|
STD_TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));
|
||||||
|
|
||||||
|
// Check for strides and alignment
|
||||||
|
STD_TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major
|
||||||
|
STD_TORCH_CHECK(b.stride(0) == 1); // Column-major
|
||||||
|
STD_TORCH_CHECK(c.stride(0) % 16 == 0 &&
|
||||||
|
b.stride(1) % 16 == 0); // 16 Byte Alignment
|
||||||
|
STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||||
|
|
||||||
|
// bias, azp, azp_adj are all 1d
|
||||||
|
// bias and azp_adj have n elements, azp has m elements
|
||||||
|
if (bias) {
|
||||||
|
STD_TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous());
|
||||||
|
}
|
||||||
|
if (azp) {
|
||||||
|
STD_TORCH_CHECK(azp->numel() == a.size(0) && azp->is_contiguous());
|
||||||
|
}
|
||||||
|
STD_TORCH_CHECK(azp_adj.numel() == b.size(1) && azp_adj.is_contiguous());
|
||||||
|
|
||||||
|
// azp & bias types
|
||||||
|
STD_TORCH_CHECK(azp_adj.scalar_type() == torch::headeronly::ScalarType::Int);
|
||||||
|
STD_TORCH_CHECK(!azp ||
|
||||||
|
azp->scalar_type() == torch::headeronly::ScalarType::Int);
|
||||||
|
STD_TORCH_CHECK(!bias || bias->scalar_type() == c.scalar_type(),
|
||||||
|
"currently bias dtype must match output dtype ",
|
||||||
|
c.scalar_type());
|
||||||
|
|
||||||
|
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||||
|
a.get_device_index());
|
||||||
|
|
||||||
|
int32_t version_num = get_sm_version_num();
|
||||||
|
|
||||||
|
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
|
||||||
|
if (version_num >= 90) {
|
||||||
|
cutlass_scaled_mm_azp_sm90(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
|
||||||
|
if (version_num == 89) {
|
||||||
|
// Ada Lovelace
|
||||||
|
cutlass_scaled_mm_azp_sm89(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (version_num >= 80) {
|
||||||
|
// Ampere
|
||||||
|
cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Turing
|
||||||
|
STD_TORCH_CHECK(version_num >= 75);
|
||||||
|
cutlass_scaled_mm_azp_sm75(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
|
||||||
|
return;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
STD_TORCH_CHECK_NOT_IMPLEMENTED(
|
||||||
|
false,
|
||||||
|
"No compiled cutlass_scaled_mm_azp for a compute capability less than "
|
||||||
|
"CUDA device capability: ",
|
||||||
|
version_num);
|
||||||
|
}
|
||||||
@@ -31,6 +31,78 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
|
|||||||
"per_token_group_quant_int8(Tensor input, Tensor! output_q, Tensor! "
|
"per_token_group_quant_int8(Tensor input, Tensor! output_q, Tensor! "
|
||||||
"output_s, int group_size, float eps, float int8_min, float int8_max) -> "
|
"output_s, int group_size, float eps, float int8_min, float int8_max) -> "
|
||||||
"()");
|
"()");
|
||||||
|
|
||||||
|
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
|
||||||
|
// quantization, as well as bias
|
||||||
|
ops.def(
|
||||||
|
"cutlass_scaled_mm(Tensor! out, Tensor a,"
|
||||||
|
" Tensor b, Tensor a_scales,"
|
||||||
|
" Tensor b_scales, Tensor? bias) -> ()");
|
||||||
|
|
||||||
|
// CUTLASS w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
|
||||||
|
// quantization.
|
||||||
|
ops.def(
|
||||||
|
"cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
|
||||||
|
" Tensor b, Tensor a_scales,"
|
||||||
|
" Tensor b_scales, Tensor azp_adj,"
|
||||||
|
" Tensor? azp, Tensor? bias) -> ()");
|
||||||
|
|
||||||
|
// Check if cutlass scaled_mm is supported for CUDA devices of the given
|
||||||
|
// capability
|
||||||
|
ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool");
|
||||||
|
|
||||||
|
// Check if cutlass grouped gemm is supported for CUDA devices of the given
|
||||||
|
// capability
|
||||||
|
ops.def("cutlass_group_gemm_supported(int cuda_device_capability) -> bool");
|
||||||
|
|
||||||
|
// CUTLASS w8a8 grouped GEMM
|
||||||
|
ops.def(
|
||||||
|
"cutlass_moe_mm(Tensor! out_tensors, Tensor a_tensors, Tensor b_tensors, "
|
||||||
|
" Tensor a_scales, Tensor b_scales, Tensor expert_offsets, "
|
||||||
|
" Tensor problem_sizes, Tensor a_strides, "
|
||||||
|
" Tensor b_strides, Tensor c_strides, bool per_act_token, "
|
||||||
|
" bool per_out_ch) -> ()");
|
||||||
|
|
||||||
|
// A function that computes data required to run fused MoE with w8a8 grouped
|
||||||
|
// GEMM. It takes topk_ids as an input, and computes expert_offsets
|
||||||
|
// (token start indices of each expert). In addition to this, it computes
|
||||||
|
// problem sizes for each expert's multiplication used by the two mms called
|
||||||
|
// from fused MoE operation, and arrays with permutations required to shuffle
|
||||||
|
// and de-shuffle the input/output of the fused operation.
|
||||||
|
ops.def(
|
||||||
|
"get_cutlass_moe_mm_data(Tensor topk_ids, Tensor! expert_offsets, "
|
||||||
|
" Tensor! problem_sizes1, Tensor! problem_sizes2, "
|
||||||
|
" Tensor! input_permutation, "
|
||||||
|
" Tensor! output_permutation, int num_experts, "
|
||||||
|
" int n, int k, Tensor? blockscale_offsets, "
|
||||||
|
" bool is_gated) -> ()");
|
||||||
|
|
||||||
|
// compute per-expert problem sizes from expert_first_token_offset
|
||||||
|
// produced by vLLM's moe_permute kernel
|
||||||
|
ops.def(
|
||||||
|
"get_cutlass_moe_mm_problem_sizes_from_expert_offsets("
|
||||||
|
" Tensor expert_first_token_offset, "
|
||||||
|
" Tensor! problem_sizes1, "
|
||||||
|
" Tensor! problem_sizes2, "
|
||||||
|
" int n, int k, bool swap_ab) -> ()");
|
||||||
|
|
||||||
|
// A function that computes data required to run fused MoE with w8a8 grouped
|
||||||
|
// GEMM in batched expert format. It takes expert_num_tokens
|
||||||
|
// as an input, and computes expert_offsets (token start indices of each
|
||||||
|
// expert). In addition to this, it computes problem sizes for each expert's
|
||||||
|
// multiplication used by the two mms called from fused MoE operation.
|
||||||
|
ops.def(
|
||||||
|
"get_cutlass_batched_moe_mm_data(Tensor! expert_offsets, "
|
||||||
|
" Tensor! problem_sizes1, "
|
||||||
|
" Tensor! problem_sizes2, "
|
||||||
|
" Tensor expert_num_tokens, "
|
||||||
|
" int num_local_experts, int padded_m, "
|
||||||
|
" int n, int k) -> ()");
|
||||||
|
|
||||||
|
// Check if cutlass scaled_mm supports block quantization (used by DeepSeekV3)
|
||||||
|
ops.def(
|
||||||
|
"cutlass_scaled_mm_supports_block_fp8(int cuda_device_capability) -> "
|
||||||
|
"bool");
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -46,6 +118,31 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
|
|||||||
TORCH_BOX(&per_token_group_quant_8bit_packed));
|
TORCH_BOX(&per_token_group_quant_8bit_packed));
|
||||||
ops.impl("per_token_group_quant_int8",
|
ops.impl("per_token_group_quant_int8",
|
||||||
TORCH_BOX(&per_token_group_quant_int8));
|
TORCH_BOX(&per_token_group_quant_int8));
|
||||||
|
|
||||||
|
// CUTLASS scaled_mm ops
|
||||||
|
ops.impl("cutlass_scaled_mm", TORCH_BOX(&cutlass_scaled_mm));
|
||||||
|
ops.impl("cutlass_scaled_mm_azp", TORCH_BOX(&cutlass_scaled_mm_azp));
|
||||||
|
ops.impl("cutlass_moe_mm", TORCH_BOX(&cutlass_moe_mm));
|
||||||
|
ops.impl("get_cutlass_moe_mm_data", TORCH_BOX(&get_cutlass_moe_mm_data));
|
||||||
|
ops.impl("get_cutlass_moe_mm_problem_sizes_from_expert_offsets",
|
||||||
|
TORCH_BOX(&get_cutlass_moe_mm_problem_sizes_from_expert_offsets));
|
||||||
|
ops.impl("get_cutlass_batched_moe_mm_data",
|
||||||
|
TORCH_BOX(&get_cutlass_batched_moe_mm_data));
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
// These capability-check functions take only primitive args (no tensors), so
|
||||||
|
// there is no device to dispatch on. CompositeExplicitAutograd makes them
|
||||||
|
// available for all backends. This is the stable ABI equivalent of calling
|
||||||
|
// ops.impl("op_name", &func) without a dispatch key in the non-stable API.
|
||||||
|
STABLE_TORCH_LIBRARY_IMPL(_C, CompositeExplicitAutograd, ops) {
|
||||||
|
#ifndef USE_ROCM
|
||||||
|
ops.impl("cutlass_scaled_mm_supports_fp8",
|
||||||
|
TORCH_BOX(&cutlass_scaled_mm_supports_fp8));
|
||||||
|
ops.impl("cutlass_group_gemm_supported",
|
||||||
|
TORCH_BOX(&cutlass_group_gemm_supported));
|
||||||
|
ops.impl("cutlass_scaled_mm_supports_block_fp8",
|
||||||
|
TORCH_BOX(&cutlass_scaled_mm_supports_block_fp8));
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,16 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
|
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
|
||||||
|
#include <torch/csrc/stable/accelerator.h>
|
||||||
|
#include <torch/csrc/stable/tensor.h>
|
||||||
#include <torch/headeronly/util/shim_utils.h>
|
#include <torch/headeronly/util/shim_utils.h>
|
||||||
|
|
||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
|
|
||||||
|
// Stable ABI equivalent of TORCH_CHECK_NOT_IMPLEMENTED.
|
||||||
|
#define STD_TORCH_CHECK_NOT_IMPLEMENTED(cond, ...) \
|
||||||
|
STD_TORCH_CHECK(cond, "NotImplementedError: ", __VA_ARGS__)
|
||||||
|
|
||||||
// Utility to get the current CUDA stream for a given device using stable APIs.
|
// Utility to get the current CUDA stream for a given device using stable APIs.
|
||||||
// Returns a cudaStream_t for use in kernel launches.
|
// Returns a cudaStream_t for use in kernel launches.
|
||||||
inline cudaStream_t get_current_cuda_stream(int32_t device_index = -1) {
|
inline cudaStream_t get_current_cuda_stream(int32_t device_index = -1) {
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ struct SSMParamsBase {
|
|||||||
int dim_ngroups_ratio;
|
int dim_ngroups_ratio;
|
||||||
bool is_variable_B;
|
bool is_variable_B;
|
||||||
bool is_variable_C;
|
bool is_variable_C;
|
||||||
int64_t pad_slot_id;
|
int64_t null_block_id;
|
||||||
|
|
||||||
bool delta_softplus;
|
bool delta_softplus;
|
||||||
bool cache_enabled;
|
bool cache_enabled;
|
||||||
|
|||||||
@@ -118,9 +118,17 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
|
|||||||
|
|
||||||
const int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr
|
const int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr
|
||||||
: reinterpret_cast<int *>(params.cache_indices_ptr);
|
: reinterpret_cast<int *>(params.cache_indices_ptr);
|
||||||
const int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];
|
int cache_index;
|
||||||
// cache_index == params.pad_slot_id is defined as padding, so we exit early
|
if (cache_indices == nullptr) {
|
||||||
if (cache_index == params.pad_slot_id){
|
cache_index = batch_id;
|
||||||
|
} else if (params.cache_enabled) {
|
||||||
|
const int* initial_state_idx = reinterpret_cast<const int*>(params.initial_state_idx_ptr);
|
||||||
|
cache_index = cache_indices[batch_id * params.cache_indices_stride + initial_state_idx[batch_id]];
|
||||||
|
} else {
|
||||||
|
cache_index = cache_indices[batch_id];
|
||||||
|
}
|
||||||
|
// Skip batch entries whose cache index maps to the null block (padding).
|
||||||
|
if (cache_indices != nullptr && cache_index == params.null_block_id){
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + sequence_start_index * params.u_batch_stride
|
input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + sequence_start_index * params.u_batch_stride
|
||||||
@@ -527,7 +535,7 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms,
|
|||||||
const std::optional<at::Tensor>& cache_indices,
|
const std::optional<at::Tensor>& cache_indices,
|
||||||
const std::optional<at::Tensor>& has_initial_state,
|
const std::optional<at::Tensor>& has_initial_state,
|
||||||
bool varlen,
|
bool varlen,
|
||||||
int64_t pad_slot_id,
|
int64_t null_block_id,
|
||||||
int64_t block_size,
|
int64_t block_size,
|
||||||
const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
|
const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
|
||||||
const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
|
const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
|
||||||
@@ -544,7 +552,7 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms,
|
|||||||
params.dstate = dstate;
|
params.dstate = dstate;
|
||||||
params.n_groups = n_groups;
|
params.n_groups = n_groups;
|
||||||
params.dim_ngroups_ratio = dim / n_groups;
|
params.dim_ngroups_ratio = dim / n_groups;
|
||||||
params.pad_slot_id = pad_slot_id;
|
params.null_block_id = null_block_id;
|
||||||
|
|
||||||
params.delta_softplus = delta_softplus;
|
params.delta_softplus = delta_softplus;
|
||||||
|
|
||||||
@@ -658,7 +666,7 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
|
|||||||
const torch::Tensor &ssm_states,
|
const torch::Tensor &ssm_states,
|
||||||
// used to identify padding entries if cache_indices provided
|
// used to identify padding entries if cache_indices provided
|
||||||
// in case of padding, the kernel will return early
|
// in case of padding, the kernel will return early
|
||||||
int64_t pad_slot_id,
|
int64_t null_block_id,
|
||||||
int64_t block_size,
|
int64_t block_size,
|
||||||
const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
|
const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
|
||||||
const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
|
const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
|
||||||
@@ -805,7 +813,7 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
|
|||||||
cache_indices,
|
cache_indices,
|
||||||
has_initial_state,
|
has_initial_state,
|
||||||
varlen,
|
varlen,
|
||||||
pad_slot_id,
|
null_block_id,
|
||||||
block_size,
|
block_size,
|
||||||
block_idx_first_scheduled_token,
|
block_idx_first_scheduled_token,
|
||||||
block_idx_last_scheduled_token,
|
block_idx_last_scheduled_token,
|
||||||
|
|||||||
47
csrc/ops.h
47
csrc/ops.h
@@ -228,63 +228,18 @@ int64_t ggml_moe_get_block_size(int64_t type);
|
|||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
|
|
||||||
bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability);
|
bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability);
|
||||||
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
|
|
||||||
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability);
|
|
||||||
bool cutlass_group_gemm_supported(int64_t cuda_device_capability);
|
|
||||||
|
|
||||||
void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A,
|
void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A,
|
||||||
torch::Tensor const& B, torch::Tensor const& A_sf,
|
torch::Tensor const& B, torch::Tensor const& A_sf,
|
||||||
torch::Tensor const& B_sf,
|
torch::Tensor const& B_sf,
|
||||||
torch::Tensor const& alpha);
|
torch::Tensor const& alpha);
|
||||||
|
|
||||||
void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
|
|
||||||
torch::Tensor const& b, torch::Tensor const& a_scales,
|
|
||||||
torch::Tensor const& b_scales,
|
|
||||||
std::optional<torch::Tensor> const& bias);
|
|
||||||
|
|
||||||
void cutlass_moe_mm(
|
|
||||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
|
||||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
|
||||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
|
||||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
|
||||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
|
|
||||||
bool per_act_token, bool per_out_ch);
|
|
||||||
|
|
||||||
void cutlass_fp4_group_mm(
|
void cutlass_fp4_group_mm(
|
||||||
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
|
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
|
||||||
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
|
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
|
||||||
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
|
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
|
||||||
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets);
|
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets);
|
||||||
|
|
||||||
void get_cutlass_moe_mm_data(
|
|
||||||
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
|
|
||||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
|
||||||
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
|
|
||||||
const int64_t num_experts, const int64_t n, const int64_t k,
|
|
||||||
const std::optional<torch::Tensor>& blockscale_offsets,
|
|
||||||
const bool is_gated);
|
|
||||||
|
|
||||||
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
|
|
||||||
const torch::Tensor& expert_first_token_offset,
|
|
||||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
|
||||||
const int64_t n, const int64_t k, const bool swap_ab);
|
|
||||||
|
|
||||||
void get_cutlass_batched_moe_mm_data(torch::Tensor& expert_offsets,
|
|
||||||
torch::Tensor& problem_sizes1,
|
|
||||||
torch::Tensor& problem_sizes2,
|
|
||||||
const torch::Tensor& expert_num_tokens,
|
|
||||||
const int64_t num_local_experts,
|
|
||||||
const int64_t padded_m, const int64_t n,
|
|
||||||
const int64_t k);
|
|
||||||
|
|
||||||
void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
|
|
||||||
torch::Tensor const& b,
|
|
||||||
torch::Tensor const& a_scales,
|
|
||||||
torch::Tensor const& b_scales,
|
|
||||||
torch::Tensor const& azp_adj,
|
|
||||||
std::optional<torch::Tensor> const& azp,
|
|
||||||
std::optional<torch::Tensor> const& bias);
|
|
||||||
|
|
||||||
std::tuple<torch::Tensor, torch::Tensor> scaled_fp4_quant_func(
|
std::tuple<torch::Tensor, torch::Tensor> scaled_fp4_quant_func(
|
||||||
torch::Tensor const& input, torch::Tensor const& input_scale,
|
torch::Tensor const& input, torch::Tensor const& input_scale,
|
||||||
bool is_sf_swizzled_layout);
|
bool is_sf_swizzled_layout);
|
||||||
@@ -343,7 +298,7 @@ void selective_scan_fwd(
|
|||||||
const std::optional<torch::Tensor>& query_start_loc,
|
const std::optional<torch::Tensor>& query_start_loc,
|
||||||
const std::optional<torch::Tensor>& cache_indices,
|
const std::optional<torch::Tensor>& cache_indices,
|
||||||
const std::optional<torch::Tensor>& has_initial_state,
|
const std::optional<torch::Tensor>& has_initial_state,
|
||||||
const torch::Tensor& ssm_states, int64_t pad_slot_id, int64_t block_size,
|
const torch::Tensor& ssm_states, int64_t null_block_id, int64_t block_size,
|
||||||
const std::optional<torch::Tensor>& block_idx_first_scheduled_token,
|
const std::optional<torch::Tensor>& block_idx_first_scheduled_token,
|
||||||
const std::optional<torch::Tensor>& block_idx_last_scheduled_token,
|
const std::optional<torch::Tensor>& block_idx_last_scheduled_token,
|
||||||
const std::optional<torch::Tensor>& initial_state_idx,
|
const std::optional<torch::Tensor>& initial_state_idx,
|
||||||
|
|||||||
@@ -16,6 +16,7 @@
|
|||||||
|
|
||||||
#include <torch/all.h>
|
#include <torch/all.h>
|
||||||
|
|
||||||
|
#include "cutlass_extensions/common.hpp"
|
||||||
#include "nvfp4_utils.cuh"
|
#include "nvfp4_utils.cuh"
|
||||||
|
|
||||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||||
@@ -53,12 +54,27 @@ void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
|
|||||||
torch::Tensor const& output_scale_offset_by_experts);
|
torch::Tensor const& output_scale_offset_by_experts);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
static bool nvfp4_quant_sm_supported() {
|
||||||
|
const int32_t sm = get_sm_version_num();
|
||||||
|
#if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100
|
||||||
|
if (sm >= 100 && sm < 120) return true;
|
||||||
|
#endif
|
||||||
|
#if defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120
|
||||||
|
if (sm >= 120 && sm < 130) return true;
|
||||||
|
#endif
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
void scaled_fp4_quant_out(torch::Tensor const& input,
|
void scaled_fp4_quant_out(torch::Tensor const& input,
|
||||||
torch::Tensor const& input_sf,
|
torch::Tensor const& input_sf,
|
||||||
bool is_sf_swizzled_layout, torch::Tensor& output,
|
bool is_sf_swizzled_layout, torch::Tensor& output,
|
||||||
torch::Tensor& output_sf) {
|
torch::Tensor& output_sf) {
|
||||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||||
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||||
|
TORCH_CHECK(nvfp4_quant_sm_supported(),
|
||||||
|
"No compiled nvfp4 quantization kernel for SM ",
|
||||||
|
get_sm_version_num(),
|
||||||
|
". Recompile with the appropriate CUDA arch.");
|
||||||
return scaled_fp4_quant_sm1xxa(output, input, output_sf, input_sf,
|
return scaled_fp4_quant_sm1xxa(output, input, output_sf, input_sf,
|
||||||
is_sf_swizzled_layout);
|
is_sf_swizzled_layout);
|
||||||
#endif
|
#endif
|
||||||
@@ -100,6 +116,10 @@ void scaled_fp4_experts_quant(
|
|||||||
torch::Tensor const& output_scale_offset_by_experts) {
|
torch::Tensor const& output_scale_offset_by_experts) {
|
||||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||||
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||||
|
TORCH_CHECK(nvfp4_quant_sm_supported(),
|
||||||
|
"No compiled nvfp4 experts quantization kernel for SM ",
|
||||||
|
get_sm_version_num(),
|
||||||
|
". Recompile with the appropriate CUDA arch.");
|
||||||
return scaled_fp4_experts_quant_sm1xxa(
|
return scaled_fp4_experts_quant_sm1xxa(
|
||||||
output, output_scale, input, input_global_scale, input_offset_by_experts,
|
output, output_scale, input, input_global_scale, input_offset_by_experts,
|
||||||
output_scale_offset_by_experts);
|
output_scale_offset_by_experts);
|
||||||
@@ -112,6 +132,10 @@ void silu_and_mul_nvfp4_quant(torch::Tensor& output, torch::Tensor& output_sf,
|
|||||||
torch::Tensor& input, torch::Tensor& input_sf) {
|
torch::Tensor& input, torch::Tensor& input_sf) {
|
||||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||||
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||||
|
TORCH_CHECK(nvfp4_quant_sm_supported(),
|
||||||
|
"No compiled silu_and_mul nvfp4 quantization kernel for SM ",
|
||||||
|
get_sm_version_num(),
|
||||||
|
". Recompile with the appropriate CUDA arch.");
|
||||||
return silu_and_mul_nvfp4_quant_sm1xxa(output, output_sf, input, input_sf);
|
return silu_and_mul_nvfp4_quant_sm1xxa(output, output_sf, input, input_sf);
|
||||||
#endif
|
#endif
|
||||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||||
@@ -125,6 +149,11 @@ void silu_and_mul_scaled_fp4_experts_quant(
|
|||||||
torch::Tensor const& output_scale_offset_by_experts) {
|
torch::Tensor const& output_scale_offset_by_experts) {
|
||||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||||
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||||
|
TORCH_CHECK(nvfp4_quant_sm_supported(),
|
||||||
|
"No compiled silu_and_mul nvfp4 experts quantization kernel "
|
||||||
|
"for SM ",
|
||||||
|
get_sm_version_num(),
|
||||||
|
". Recompile with the appropriate CUDA arch.");
|
||||||
return silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
|
return silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
|
||||||
output, output_scale, input, input_global_scale, input_offset_by_experts,
|
output, output_scale, input, input_global_scale, input_offset_by_experts,
|
||||||
output_scale_offset_by_experts);
|
output_scale_offset_by_experts);
|
||||||
|
|||||||
@@ -63,5 +63,17 @@ void cutlass_scaled_fp4_mm(torch::Tensor& D, const torch::Tensor& A,
|
|||||||
bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability) {
|
bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability) {
|
||||||
int runtimeVersion;
|
int runtimeVersion;
|
||||||
cudaRuntimeGetVersion(&runtimeVersion);
|
cudaRuntimeGetVersion(&runtimeVersion);
|
||||||
return cuda_device_capability >= 100 && runtimeVersion >= 12080;
|
if (runtimeVersion < 12080) return false;
|
||||||
|
// Only report support when the SM-specific kernel was actually compiled in,
|
||||||
|
// so the Python-side backend selector does not choose CUTLASS and then hit
|
||||||
|
// TORCH_CHECK_NOT_IMPLEMENTED (or worse, fall through to Marlin).
|
||||||
|
#if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100
|
||||||
|
if (cuda_device_capability >= 100 && cuda_device_capability < 120)
|
||||||
|
return true;
|
||||||
|
#endif
|
||||||
|
#if defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120
|
||||||
|
if (cuda_device_capability >= 120 && cuda_device_capability < 130)
|
||||||
|
return true;
|
||||||
|
#endif
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -154,6 +154,7 @@ struct MacheteCollectiveMma {
|
|||||||
struct DispatchPolicy {
|
struct DispatchPolicy {
|
||||||
constexpr static int Stages = PipelineStages;
|
constexpr static int Stages = PipelineStages;
|
||||||
using ClusterShape = ClusterShape_MNK;
|
using ClusterShape = ClusterShape_MNK;
|
||||||
|
using ArchTag = arch::Sm90;
|
||||||
using Schedule = KernelScheduleType;
|
using Schedule = KernelScheduleType;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -1,23 +0,0 @@
|
|||||||
#include "scaled_mm_kernels.hpp"
|
|
||||||
#include "scaled_mm_blockwise_sm100_fp8_dispatch.cuh"
|
|
||||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
|
||||||
|
|
||||||
namespace vllm {
|
|
||||||
|
|
||||||
void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
|
|
||||||
torch::Tensor const& a,
|
|
||||||
torch::Tensor const& b,
|
|
||||||
torch::Tensor const& a_scales,
|
|
||||||
torch::Tensor const& b_scales) {
|
|
||||||
if (out.dtype() == torch::kBFloat16) {
|
|
||||||
cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::bfloat16_t>(
|
|
||||||
out, a, b, a_scales, b_scales);
|
|
||||||
|
|
||||||
} else {
|
|
||||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
|
||||||
cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::half_t>(
|
|
||||||
out, a, b, a_scales, b_scales);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace vllm
|
|
||||||
@@ -1,23 +0,0 @@
|
|||||||
#include "scaled_mm_kernels.hpp"
|
|
||||||
#include "scaled_mm_blockwise_sm120_fp8_dispatch.cuh"
|
|
||||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
|
||||||
|
|
||||||
namespace vllm {
|
|
||||||
|
|
||||||
void cutlass_scaled_mm_blockwise_sm120_fp8(torch::Tensor& out,
|
|
||||||
torch::Tensor const& a,
|
|
||||||
torch::Tensor const& b,
|
|
||||||
torch::Tensor const& a_scales,
|
|
||||||
torch::Tensor const& b_scales) {
|
|
||||||
if (out.dtype() == torch::kBFloat16) {
|
|
||||||
cutlass_gemm_blockwise_sm120_fp8_dispatch<cutlass::bfloat16_t>(
|
|
||||||
out, a, b, a_scales, b_scales);
|
|
||||||
|
|
||||||
} else {
|
|
||||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
|
||||||
cutlass_gemm_blockwise_sm120_fp8_dispatch<cutlass::half_t>(
|
|
||||||
out, a, b, a_scales, b_scales);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace vllm
|
|
||||||
@@ -1,24 +0,0 @@
|
|||||||
|
|
||||||
#include "scaled_mm_kernels.hpp"
|
|
||||||
#include "scaled_mm_blockwise_sm90_fp8_dispatch.cuh"
|
|
||||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
|
||||||
|
|
||||||
namespace vllm {
|
|
||||||
|
|
||||||
void cutlass_scaled_mm_blockwise_sm90_fp8(torch::Tensor& out,
|
|
||||||
torch::Tensor const& a,
|
|
||||||
torch::Tensor const& b,
|
|
||||||
torch::Tensor const& a_scales,
|
|
||||||
torch::Tensor const& b_scales) {
|
|
||||||
if (out.dtype() == torch::kBFloat16) {
|
|
||||||
cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::bfloat16_t>(
|
|
||||||
out, a, b, a_scales, b_scales);
|
|
||||||
|
|
||||||
} else {
|
|
||||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
|
||||||
cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::half_t>(
|
|
||||||
out, a, b, a_scales, b_scales);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace vllm
|
|
||||||
@@ -1,56 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include <torch/all.h>
|
|
||||||
|
|
||||||
namespace vllm {
|
|
||||||
|
|
||||||
void cutlass_scaled_mm_sm90_fp8(torch::Tensor& out, torch::Tensor const& a,
|
|
||||||
torch::Tensor const& b,
|
|
||||||
torch::Tensor const& a_scales,
|
|
||||||
torch::Tensor const& b_scales,
|
|
||||||
std::optional<torch::Tensor> const& bias);
|
|
||||||
|
|
||||||
void cutlass_scaled_mm_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
|
|
||||||
torch::Tensor const& b,
|
|
||||||
torch::Tensor const& a_scales,
|
|
||||||
torch::Tensor const& b_scales,
|
|
||||||
std::optional<torch::Tensor> const& bias);
|
|
||||||
|
|
||||||
void cutlass_scaled_mm_azp_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
|
|
||||||
torch::Tensor const& b,
|
|
||||||
torch::Tensor const& a_scales,
|
|
||||||
torch::Tensor const& b_scales,
|
|
||||||
torch::Tensor const& azp_adj,
|
|
||||||
std::optional<torch::Tensor> const& azp,
|
|
||||||
std::optional<torch::Tensor> const& bias);
|
|
||||||
|
|
||||||
void cutlass_scaled_mm_blockwise_sm90_fp8(torch::Tensor& out,
|
|
||||||
torch::Tensor const& a,
|
|
||||||
torch::Tensor const& b,
|
|
||||||
torch::Tensor const& a_scales,
|
|
||||||
torch::Tensor const& b_scales);
|
|
||||||
|
|
||||||
void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
|
|
||||||
torch::Tensor const& b,
|
|
||||||
torch::Tensor const& a_scales,
|
|
||||||
torch::Tensor const& b_scales,
|
|
||||||
std::optional<torch::Tensor> const& bias);
|
|
||||||
|
|
||||||
void cutlass_scaled_mm_sm120_fp8(torch::Tensor& out, torch::Tensor const& a,
|
|
||||||
torch::Tensor const& b,
|
|
||||||
torch::Tensor const& a_scales,
|
|
||||||
torch::Tensor const& b_scales,
|
|
||||||
std::optional<torch::Tensor> const& bias);
|
|
||||||
|
|
||||||
void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
|
|
||||||
torch::Tensor const& a,
|
|
||||||
torch::Tensor const& b,
|
|
||||||
torch::Tensor const& a_scales,
|
|
||||||
torch::Tensor const& b_scales);
|
|
||||||
|
|
||||||
void cutlass_scaled_mm_blockwise_sm120_fp8(torch::Tensor& out,
|
|
||||||
torch::Tensor const& a,
|
|
||||||
torch::Tensor const& b,
|
|
||||||
torch::Tensor const& a_scales,
|
|
||||||
torch::Tensor const& b_scales);
|
|
||||||
} // namespace vllm
|
|
||||||
@@ -1,23 +0,0 @@
|
|||||||
#include "scaled_mm_kernels.hpp"
|
|
||||||
#include "scaled_mm_sm100_fp8_dispatch.cuh"
|
|
||||||
|
|
||||||
namespace vllm {
|
|
||||||
|
|
||||||
void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
|
|
||||||
torch::Tensor const& b,
|
|
||||||
torch::Tensor const& a_scales,
|
|
||||||
torch::Tensor const& b_scales,
|
|
||||||
std::optional<torch::Tensor> const& bias) {
|
|
||||||
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
|
||||||
if (bias) {
|
|
||||||
TORCH_CHECK(bias->dtype() == out.dtype(),
|
|
||||||
"currently bias dtype must match output dtype ", out.dtype());
|
|
||||||
return cutlass_scaled_mm_sm100_fp8_epilogue<true>(out, a, b, a_scales,
|
|
||||||
b_scales, *bias);
|
|
||||||
} else {
|
|
||||||
return cutlass_scaled_mm_sm100_fp8_epilogue<false>(out, a, b, a_scales,
|
|
||||||
b_scales);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace vllm
|
|
||||||
@@ -1,24 +0,0 @@
|
|||||||
#include "scaled_mm_kernels.hpp"
|
|
||||||
#include "scaled_mm_sm120_fp8_dispatch.cuh"
|
|
||||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
|
||||||
|
|
||||||
namespace vllm {
|
|
||||||
|
|
||||||
void cutlass_scaled_mm_sm120_fp8(torch::Tensor& out, torch::Tensor const& a,
|
|
||||||
torch::Tensor const& b,
|
|
||||||
torch::Tensor const& a_scales,
|
|
||||||
torch::Tensor const& b_scales,
|
|
||||||
std::optional<torch::Tensor> const& bias) {
|
|
||||||
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
|
||||||
if (bias) {
|
|
||||||
TORCH_CHECK(bias->dtype() == out.dtype(),
|
|
||||||
"currently bias dtype must match output dtype ", out.dtype());
|
|
||||||
return cutlass_scaled_mm_sm120_fp8_epilogue<c3x::ScaledEpilogueBias>(
|
|
||||||
out, a, b, a_scales, b_scales, *bias);
|
|
||||||
} else {
|
|
||||||
return cutlass_scaled_mm_sm120_fp8_epilogue<c3x::ScaledEpilogue>(
|
|
||||||
out, a, b, a_scales, b_scales);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace vllm
|
|
||||||
@@ -1,23 +0,0 @@
|
|||||||
#include "scaled_mm_kernels.hpp"
|
|
||||||
#include "scaled_mm_sm90_fp8_dispatch.cuh"
|
|
||||||
|
|
||||||
namespace vllm {
|
|
||||||
|
|
||||||
void cutlass_scaled_mm_sm90_fp8(torch::Tensor& out, torch::Tensor const& a,
|
|
||||||
torch::Tensor const& b,
|
|
||||||
torch::Tensor const& a_scales,
|
|
||||||
torch::Tensor const& b_scales,
|
|
||||||
std::optional<torch::Tensor> const& bias) {
|
|
||||||
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
|
||||||
if (bias) {
|
|
||||||
TORCH_CHECK(bias->dtype() == out.dtype(),
|
|
||||||
"currently bias dtype must match output dtype ", out.dtype());
|
|
||||||
return cutlass_scaled_mm_sm90_fp8_epilogue<true>(out, a, b, a_scales,
|
|
||||||
b_scales, *bias);
|
|
||||||
} else {
|
|
||||||
return cutlass_scaled_mm_sm90_fp8_epilogue<false>(out, a, b, a_scales,
|
|
||||||
b_scales);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace vllm
|
|
||||||
@@ -1,24 +0,0 @@
|
|||||||
#include "scaled_mm_kernels.hpp"
|
|
||||||
#include "scaled_mm_sm90_int8_dispatch.cuh"
|
|
||||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
|
||||||
|
|
||||||
namespace vllm {
|
|
||||||
|
|
||||||
void cutlass_scaled_mm_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
|
|
||||||
torch::Tensor const& b,
|
|
||||||
torch::Tensor const& a_scales,
|
|
||||||
torch::Tensor const& b_scales,
|
|
||||||
std::optional<torch::Tensor> const& bias) {
|
|
||||||
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
|
||||||
if (bias) {
|
|
||||||
TORCH_CHECK(bias->dtype() == out.dtype(),
|
|
||||||
"currently bias dtype must match output dtype ", out.dtype());
|
|
||||||
return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogueBias>(
|
|
||||||
out, a, b, a_scales, b_scales, *bias);
|
|
||||||
} else {
|
|
||||||
return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogue>(
|
|
||||||
out, a, b, a_scales, b_scales);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace vllm
|
|
||||||
@@ -1,199 +0,0 @@
|
|||||||
#include <stddef.h>
|
|
||||||
#include <torch/all.h>
|
|
||||||
#include "cutlass/cutlass.h"
|
|
||||||
|
|
||||||
#include "scaled_mm_c2x.cuh"
|
|
||||||
#include "scaled_mm_c2x_sm75_dispatch.cuh"
|
|
||||||
#include "scaled_mm_c2x_sm80_dispatch.cuh"
|
|
||||||
#include "scaled_mm_c2x_sm89_fp8_dispatch.cuh"
|
|
||||||
#include "scaled_mm_c2x_sm89_int8_dispatch.cuh"
|
|
||||||
|
|
||||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp"
|
|
||||||
|
|
||||||
using namespace vllm;
|
|
||||||
|
|
||||||
/*
|
|
||||||
This file defines quantized GEMM operations using the CUTLASS 2.x API, for
|
|
||||||
NVIDIA GPUs with SM versions prior to sm90 (Hopper).
|
|
||||||
*/
|
|
||||||
|
|
||||||
template <template <typename, typename> typename Epilogue,
|
|
||||||
typename... EpilogueArgs>
|
|
||||||
void cutlass_scaled_mm_sm75_epilogue(torch::Tensor& out, torch::Tensor const& a,
|
|
||||||
torch::Tensor const& b,
|
|
||||||
EpilogueArgs&&... epilogue_args) {
|
|
||||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
|
||||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
|
||||||
|
|
||||||
if (out.dtype() == torch::kBFloat16) {
|
|
||||||
return cutlass_gemm_sm75_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
|
|
||||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
|
||||||
} else {
|
|
||||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
|
||||||
return cutlass_gemm_sm75_dispatch<int8_t, cutlass::half_t, Epilogue>(
|
|
||||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
|
|
||||||
torch::Tensor const& b,
|
|
||||||
torch::Tensor const& a_scales,
|
|
||||||
torch::Tensor const& b_scales,
|
|
||||||
std::optional<torch::Tensor> const& bias) {
|
|
||||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
|
||||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
|
||||||
if (bias) {
|
|
||||||
TORCH_CHECK(bias->dtype() == out.dtype(),
|
|
||||||
"currently bias dtype must match output dtype ", out.dtype());
|
|
||||||
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBias>(
|
|
||||||
out, a, b, a_scales, b_scales, *bias);
|
|
||||||
} else {
|
|
||||||
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogue>(
|
|
||||||
out, a, b, a_scales, b_scales);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void cutlass_scaled_mm_azp_sm75(torch::Tensor& out, torch::Tensor const& a,
|
|
||||||
torch::Tensor const& b,
|
|
||||||
torch::Tensor const& a_scales,
|
|
||||||
torch::Tensor const& b_scales,
|
|
||||||
torch::Tensor const& azp_adj,
|
|
||||||
std::optional<torch::Tensor> const& azp,
|
|
||||||
std::optional<torch::Tensor> const& bias) {
|
|
||||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
|
||||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
|
||||||
|
|
||||||
if (azp) {
|
|
||||||
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
|
|
||||||
out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
|
|
||||||
} else {
|
|
||||||
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBiasAzp>(
|
|
||||||
out, a, b, a_scales, b_scales, azp_adj, bias);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <template <typename, typename> typename Epilogue,
|
|
||||||
typename... EpilogueArgs>
|
|
||||||
void cutlass_scaled_mm_sm80_epilogue(torch::Tensor& out, torch::Tensor const& a,
|
|
||||||
torch::Tensor const& b,
|
|
||||||
EpilogueArgs&&... epilogue_args) {
|
|
||||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
|
||||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
|
||||||
|
|
||||||
if (out.dtype() == torch::kBFloat16) {
|
|
||||||
return cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
|
|
||||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
|
||||||
} else {
|
|
||||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
|
||||||
return cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t, Epilogue>(
|
|
||||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
|
|
||||||
torch::Tensor const& b,
|
|
||||||
torch::Tensor const& a_scales,
|
|
||||||
torch::Tensor const& b_scales,
|
|
||||||
std::optional<torch::Tensor> const& bias) {
|
|
||||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
|
||||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
|
||||||
if (bias) {
|
|
||||||
TORCH_CHECK(bias->dtype() == out.dtype(),
|
|
||||||
"currently bias dtype must match output dtype ", out.dtype());
|
|
||||||
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBias>(
|
|
||||||
out, a, b, a_scales, b_scales, *bias);
|
|
||||||
} else {
|
|
||||||
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogue>(
|
|
||||||
out, a, b, a_scales, b_scales);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void cutlass_scaled_mm_azp_sm80(torch::Tensor& out, torch::Tensor const& a,
|
|
||||||
torch::Tensor const& b,
|
|
||||||
torch::Tensor const& a_scales,
|
|
||||||
torch::Tensor const& b_scales,
|
|
||||||
torch::Tensor const& azp_adj,
|
|
||||||
std::optional<torch::Tensor> const& azp,
|
|
||||||
std::optional<torch::Tensor> const& bias) {
|
|
||||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
|
||||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
|
||||||
|
|
||||||
if (azp) {
|
|
||||||
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
|
|
||||||
out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
|
|
||||||
} else {
|
|
||||||
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBiasAzp>(
|
|
||||||
out, a, b, a_scales, b_scales, azp_adj, bias);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <template <typename, typename> typename Epilogue,
|
|
||||||
typename... EpilogueArgs>
|
|
||||||
void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
|
|
||||||
torch::Tensor const& b,
|
|
||||||
EpilogueArgs&&... epilogue_args) {
|
|
||||||
if (a.dtype() == torch::kInt8) {
|
|
||||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
|
||||||
|
|
||||||
if (out.dtype() == torch::kBFloat16) {
|
|
||||||
return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::bfloat16_t,
|
|
||||||
Epilogue>(
|
|
||||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
|
||||||
} else {
|
|
||||||
assert(out.dtype() == torch::kFloat16);
|
|
||||||
return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
|
|
||||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
|
||||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
|
||||||
|
|
||||||
if (out.dtype() == torch::kBFloat16) {
|
|
||||||
return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
|
|
||||||
cutlass::bfloat16_t, Epilogue>(
|
|
||||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
|
||||||
} else {
|
|
||||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
|
||||||
return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
|
|
||||||
cutlass::half_t, Epilogue>(
|
|
||||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
|
|
||||||
torch::Tensor const& b,
|
|
||||||
torch::Tensor const& a_scales,
|
|
||||||
torch::Tensor const& b_scales,
|
|
||||||
std::optional<torch::Tensor> const& bias) {
|
|
||||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
|
||||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
|
||||||
if (bias) {
|
|
||||||
TORCH_CHECK(bias->dtype() == out.dtype(),
|
|
||||||
"currently bias dtype must match output dtype ", out.dtype());
|
|
||||||
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBias>(
|
|
||||||
out, a, b, a_scales, b_scales, *bias);
|
|
||||||
} else {
|
|
||||||
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogue>(
|
|
||||||
out, a, b, a_scales, b_scales);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void cutlass_scaled_mm_azp_sm89(torch::Tensor& out, torch::Tensor const& a,
|
|
||||||
torch::Tensor const& b,
|
|
||||||
torch::Tensor const& a_scales,
|
|
||||||
torch::Tensor const& b_scales,
|
|
||||||
torch::Tensor const& azp_adj,
|
|
||||||
std::optional<torch::Tensor> const& azp,
|
|
||||||
std::optional<torch::Tensor> const& bias) {
|
|
||||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
|
||||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
|
||||||
|
|
||||||
if (azp) {
|
|
||||||
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
|
|
||||||
out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
|
|
||||||
} else {
|
|
||||||
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBiasAzp>(
|
|
||||||
out, a, b, a_scales, b_scales, azp_adj, bias);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,36 +0,0 @@
|
|||||||
#include "c3x/scaled_mm_helper.hpp"
|
|
||||||
#include "c3x/scaled_mm_kernels.hpp"
|
|
||||||
|
|
||||||
/*
|
|
||||||
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
|
|
||||||
NVIDIA GPUs with sm90a (Hopper).
|
|
||||||
*/
|
|
||||||
|
|
||||||
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
|
|
||||||
|
|
||||||
void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
|
|
||||||
torch::Tensor const& b,
|
|
||||||
torch::Tensor const& a_scales,
|
|
||||||
torch::Tensor const& b_scales,
|
|
||||||
std::optional<torch::Tensor> const& bias) {
|
|
||||||
dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
|
|
||||||
vllm::cutlass_scaled_mm_sm90_fp8,
|
|
||||||
vllm::cutlass_scaled_mm_sm90_int8,
|
|
||||||
vllm::cutlass_scaled_mm_blockwise_sm90_fp8);
|
|
||||||
}
|
|
||||||
|
|
||||||
void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
|
|
||||||
torch::Tensor const& b,
|
|
||||||
torch::Tensor const& a_scales,
|
|
||||||
torch::Tensor const& b_scales,
|
|
||||||
torch::Tensor const& azp_adj,
|
|
||||||
std::optional<torch::Tensor> const& azp,
|
|
||||||
std::optional<torch::Tensor> const& bias) {
|
|
||||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
|
||||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
|
||||||
|
|
||||||
vllm::cutlass_scaled_mm_azp_sm90_int8(out, a, b, a_scales, b_scales, azp_adj,
|
|
||||||
azp, bias);
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif
|
|
||||||
@@ -1,420 +0,0 @@
|
|||||||
#include <cudaTypedefs.h>
|
|
||||||
|
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
|
||||||
#include <torch/all.h>
|
|
||||||
|
|
||||||
#include "cutlass_extensions/common.hpp"
|
|
||||||
|
|
||||||
void cutlass_scaled_mm_sm75(torch::Tensor& c, torch::Tensor const& a,
|
|
||||||
torch::Tensor const& b,
|
|
||||||
torch::Tensor const& a_scales,
|
|
||||||
torch::Tensor const& b_scales,
|
|
||||||
std::optional<torch::Tensor> const& bias);
|
|
||||||
|
|
||||||
void cutlass_scaled_mm_sm80(torch::Tensor& c, torch::Tensor const& a,
|
|
||||||
torch::Tensor const& b,
|
|
||||||
torch::Tensor const& a_scales,
|
|
||||||
torch::Tensor const& b_scales,
|
|
||||||
std::optional<torch::Tensor> const& bias);
|
|
||||||
|
|
||||||
void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a,
|
|
||||||
torch::Tensor const& b,
|
|
||||||
torch::Tensor const& a_scales,
|
|
||||||
torch::Tensor const& b_scales,
|
|
||||||
std::optional<torch::Tensor> const& bias);
|
|
||||||
|
|
||||||
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
|
|
||||||
void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
|
|
||||||
torch::Tensor const& b,
|
|
||||||
torch::Tensor const& a_scales,
|
|
||||||
torch::Tensor const& b_scales,
|
|
||||||
std::optional<torch::Tensor> const& bias);
|
|
||||||
#endif
|
|
||||||
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
|
|
||||||
void cutlass_moe_mm_sm90(
|
|
||||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
|
||||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
|
||||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
|
||||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
|
||||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
|
|
||||||
bool per_act_token, bool per_out_ch);
|
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100
|
|
||||||
void cutlass_moe_mm_sm100(
|
|
||||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
|
||||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
|
||||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
|
||||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
|
||||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
|
|
||||||
bool per_act_token, bool per_out_ch);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
|
|
||||||
void cutlass_scaled_mm_sm120(torch::Tensor& c, torch::Tensor const& a,
|
|
||||||
torch::Tensor const& b,
|
|
||||||
torch::Tensor const& a_scales,
|
|
||||||
torch::Tensor const& b_scales,
|
|
||||||
std::optional<torch::Tensor> const& bias);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
|
|
||||||
void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
|
|
||||||
torch::Tensor const& b,
|
|
||||||
torch::Tensor const& a_scales,
|
|
||||||
torch::Tensor const& b_scales,
|
|
||||||
std::optional<torch::Tensor> const& bias);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#if (defined(ENABLE_CUTLASS_MOE_SM90) && ENABLE_CUTLASS_MOE_SM90) || \
|
|
||||||
(defined(ENABLE_CUTLASS_MOE_SM100) && ENABLE_CUTLASS_MOE_SM100) || \
|
|
||||||
(defined(ENABLE_CUTLASS_MOE_SM120) && ENABLE_CUTLASS_MOE_SM120)
|
|
||||||
void get_cutlass_moe_mm_data_caller(
|
|
||||||
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
|
|
||||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
|
||||||
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
|
|
||||||
const int64_t num_experts, const int64_t n, const int64_t k,
|
|
||||||
const std::optional<torch::Tensor>& blockscale_offsets,
|
|
||||||
const bool is_gated);
|
|
||||||
|
|
||||||
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
|
|
||||||
const torch::Tensor& expert_first_token_offset,
|
|
||||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
|
||||||
const int64_t n, const int64_t k, const bool swap_ab);
|
|
||||||
|
|
||||||
void get_cutlass_batched_moe_mm_data_caller(
|
|
||||||
torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1,
|
|
||||||
torch::Tensor& problem_sizes2, const torch::Tensor& expert_num_tokens,
|
|
||||||
const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
|
|
||||||
const int64_t k);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
|
|
||||||
torch::Tensor const& b,
|
|
||||||
torch::Tensor const& a_scales,
|
|
||||||
torch::Tensor const& b_scales,
|
|
||||||
torch::Tensor const& azp_adj,
|
|
||||||
std::optional<torch::Tensor> const& azp,
|
|
||||||
std::optional<torch::Tensor> const& bias);
|
|
||||||
|
|
||||||
void cutlass_scaled_mm_azp_sm80(torch::Tensor& c, torch::Tensor const& a,
|
|
||||||
torch::Tensor const& b,
|
|
||||||
torch::Tensor const& a_scales,
|
|
||||||
torch::Tensor const& b_scales,
|
|
||||||
torch::Tensor const& azp_adj,
|
|
||||||
std::optional<torch::Tensor> const& azp,
|
|
||||||
std::optional<torch::Tensor> const& bias);
|
|
||||||
|
|
||||||
void cutlass_scaled_mm_azp_sm89(torch::Tensor& c, torch::Tensor const& a,
|
|
||||||
torch::Tensor const& b,
|
|
||||||
torch::Tensor const& a_scales,
|
|
||||||
torch::Tensor const& b_scales,
|
|
||||||
torch::Tensor const& azp_adj,
|
|
||||||
std::optional<torch::Tensor> const& azp,
|
|
||||||
std::optional<torch::Tensor> const& bias);
|
|
||||||
|
|
||||||
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
|
|
||||||
void cutlass_scaled_mm_azp_sm90(torch::Tensor& c, torch::Tensor const& a,
|
|
||||||
torch::Tensor const& b,
|
|
||||||
torch::Tensor const& a_scales,
|
|
||||||
torch::Tensor const& b_scales,
|
|
||||||
torch::Tensor const& azp_adj,
|
|
||||||
std::optional<torch::Tensor> const& azp,
|
|
||||||
std::optional<torch::Tensor> const& bias);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
|
|
||||||
// CUTLASS FP8 kernels need at least
|
|
||||||
// CUDA 12.0 on SM90 systems (Hopper)
|
|
||||||
// CUDA 12.4 on SM89 systems (Lovelace)
|
|
||||||
|
|
||||||
#if defined CUDA_VERSION
|
|
||||||
if (cuda_device_capability >= 90) {
|
|
||||||
return CUDA_VERSION >= 12000;
|
|
||||||
} else if (cuda_device_capability >= 89) {
|
|
||||||
return CUDA_VERSION >= 12040;
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
|
|
||||||
// CUTLASS block-quantized FP8 kernels need at least CUDA 12.0
|
|
||||||
// and at least SM90 (Hopper)
|
|
||||||
|
|
||||||
#if defined CUDA_VERSION
|
|
||||||
if (cuda_device_capability >= 100) {
|
|
||||||
return CUDA_VERSION >= 12080;
|
|
||||||
} else if (cuda_device_capability >= 90) {
|
|
||||||
return CUDA_VERSION >= 12000;
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool cutlass_group_gemm_supported(int64_t cuda_device_capability) {
|
|
||||||
// CUTLASS grouped FP8 kernels need at least CUDA 12.3 and SM90 (Hopper)
|
|
||||||
// or CUDA 12.8 and SM100 (Blackwell)
|
|
||||||
|
|
||||||
#if defined CUDA_VERSION
|
|
||||||
if (cuda_device_capability >= 100) {
|
|
||||||
return CUDA_VERSION >= 12080;
|
|
||||||
}
|
|
||||||
if (cuda_device_capability >= 90) {
|
|
||||||
return CUDA_VERSION >= 12030;
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
                       torch::Tensor const& b, torch::Tensor const& a_scales,
                       torch::Tensor const& b_scales,
                       std::optional<torch::Tensor> const& bias) {
  // Symmetric w8a8 scaled GEMM entry point. Validates operand conformality
  // and layout, then dispatches to the newest CUTLASS kernel generation
  // compiled in for the current device's SM version.

  // Conformality: all three operands 2-D with matching GEMM dimensions.
  TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
  TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
              b.size(1) == c.size(1));

  // Layout: a and c row-major, b column-major; leading strides 16B-aligned.
  TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major
  TORCH_CHECK(b.stride(0) == 1);                      // Column-major
  TORCH_CHECK(c.stride(0) % 16 == 0 &&
              b.stride(1) % 16 == 0);  // 16 Byte Alignment

  // Optional bias: contiguous 1-D vector, one entry per output column.
  if (bias) {
    TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
                bias->dim() == 1);
  }

  at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
  int32_t const sm_version = get_sm_version_num();

  // Each kernel generation is conditionally compiled; try newest first.
#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
  if (sm_version >= 120) {
    cutlass_scaled_mm_sm120(c, a, b, a_scales, b_scales, bias);
    return;
  }
#endif

#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
  if (sm_version >= 100 && sm_version < 120) {
    cutlass_scaled_mm_sm100(c, a, b, a_scales, b_scales, bias);
    return;
  }
#endif

  // Guarded separately against compilation issues for sm90 kernels.
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
  if (sm_version >= 90 && sm_version < 100) {
    // Hopper
    cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
    return;
  }
#endif

#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
  if (sm_version == 89) {
    // Ada Lovelace
    cutlass_scaled_mm_sm89(c, a, b, a_scales, b_scales, bias);
    return;
  }
  if (sm_version >= 80) {
    // Ampere
    cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias);
    return;
  }
  if (sm_version >= 75) {
    // Turing
    cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias);
    return;
  }
#endif

  // No compiled kernel matched this device.
  TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "No compiled cutlass_scaled_mm for a compute capability less than "
      "CUDA device capability: ",
      sm_version);
}
|
|
||||||
|
|
||||||
void cutlass_moe_mm(
    torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
    torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
    torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
    torch::Tensor const& b_strides, torch::Tensor const& c_strides,
    bool per_act_token, bool per_out_ch) {
  // Grouped (MoE) w8a8 GEMM: runs one scaled matmul per expert using the
  // per-expert token offsets, problem sizes, and strides precomputed by
  // get_cutlass_moe_mm_data. Dispatches on the device's SM version to the
  // grouped-GEMM kernel generation that was compiled in.
  int32_t version_num = get_sm_version_num();
#if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100
  if (version_num >= 100 && version_num < 110) {
    // SM100 (Blackwell) grouped kernels
    cutlass_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
                         expert_offsets, problem_sizes, a_strides, b_strides,
                         c_strides, per_act_token, per_out_ch);
    return;
  }
#endif
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
  if (version_num >= 90 && version_num < 100) {
    // SM90 (Hopper) grouped kernels
    cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
                        expert_offsets, problem_sizes, a_strides, b_strides,
                        c_strides, per_act_token, per_out_ch);
    return;
  }
#endif
  // Fix: the message previously named cutlass_scaled_mm (copy-paste error);
  // report the actual op so unsupported-device failures are attributable.
  TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "No compiled cutlass_moe_mm for CUDA device capability: ", version_num,
      ". Required capability: 90 or 100");
}
|
|
||||||
|
|
||||||
void get_cutlass_moe_mm_data(
|
|
||||||
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
|
|
||||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
|
||||||
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
|
|
||||||
const int64_t num_experts, const int64_t n, const int64_t k,
|
|
||||||
const std::optional<torch::Tensor>& blockscale_offsets,
|
|
||||||
const bool is_gated) {
|
|
||||||
// This function currently gets compiled only if we have a valid cutlass moe
|
|
||||||
// mm to run it for.
|
|
||||||
int32_t version_num = get_sm_version_num();
|
|
||||||
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
|
|
||||||
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
|
|
||||||
(defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
|
|
||||||
get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1,
|
|
||||||
problem_sizes2, input_permutation,
|
|
||||||
output_permutation, num_experts, n, k,
|
|
||||||
blockscale_offsets, is_gated);
|
|
||||||
return;
|
|
||||||
#endif
|
|
||||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
|
||||||
false,
|
|
||||||
"No compiled get_cutlass_moe_mm_data: no cutlass_scaled_mm kernel for "
|
|
||||||
"CUDA device capability: ",
|
|
||||||
version_num, ". Required capability: 90, 100, or 120");
|
|
||||||
}
|
|
||||||
|
|
||||||
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
|
|
||||||
const torch::Tensor& expert_first_token_offset,
|
|
||||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
|
||||||
const int64_t n, const int64_t k, const bool swap_ab) {
|
|
||||||
int32_t version_num = get_sm_version_num();
|
|
||||||
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
|
|
||||||
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
|
|
||||||
(defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
|
|
||||||
get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
|
|
||||||
expert_first_token_offset, problem_sizes1, problem_sizes2, n, k, swap_ab);
|
|
||||||
return;
|
|
||||||
#endif
|
|
||||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
|
||||||
false,
|
|
||||||
"No compiled get_cutlass_moe_mm_problem_sizes_from_expert_offsets: "
|
|
||||||
"no cutlass_scaled_mm kernel for CUDA device capability: ",
|
|
||||||
version_num, ". Required capability: 90, 100, or 120");
|
|
||||||
}
|
|
||||||
|
|
||||||
void get_cutlass_batched_moe_mm_data(torch::Tensor& expert_offsets,
                                     torch::Tensor& problem_sizes1,
                                     torch::Tensor& problem_sizes2,
                                     const torch::Tensor& expert_num_tokens,
                                     const int64_t num_local_experts,
                                     const int64_t padded_m, const int64_t n,
                                     const int64_t k) {
  // Batched-expert-format variant of get_cutlass_moe_mm_data: derives
  // per-expert token start offsets and the problem sizes for both grouped
  // matmuls from expert_num_tokens.
  //
  // This function currently gets compiled only if we have a valid cutlass
  // moe mm to run it for.
  int32_t sm_version = get_sm_version_num();
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) ||   \
    (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
    (defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
  get_cutlass_batched_moe_mm_data_caller(expert_offsets, problem_sizes1,
                                         problem_sizes2, expert_num_tokens,
                                         num_local_experts, padded_m, n, k);
  return;
#endif
  TORCH_CHECK_NOT_IMPLEMENTED(false,
                              "No compiled get_cutlass_batched_moe_mm_data: no "
                              "cutlass_scaled_mm kernel "
                              "for CUDA device capability: ",
                              sm_version,
                              ". Required capability: 90, 100, or 120");
}
|
|
||||||
|
|
||||||
void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
                           torch::Tensor const& b,
                           torch::Tensor const& a_scales,
                           torch::Tensor const& b_scales,
                           torch::Tensor const& azp_adj,
                           std::optional<torch::Tensor> const& azp,
                           std::optional<torch::Tensor> const& bias) {
  // Asymmetric w8a8 scaled GEMM with azp (activation zero-point) correction.
  // Validates operands, then dispatches to the CUTLASS azp kernel generation
  // compiled in for the current device's SM version.

  // Conformality: 2-D operands with matching GEMM dimensions; scales are
  // either per-tensor (numel == 1) or per-row/per-column.
  TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
  TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
              b.size(1) == c.size(1));
  TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
  TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));

  // Layout: a and c row-major, b column-major; leading strides 16B-aligned;
  // scale tensors contiguous.
  TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major
  TORCH_CHECK(b.stride(0) == 1);                      // Column-major
  TORCH_CHECK(c.stride(0) % 16 == 0 &&
              b.stride(1) % 16 == 0);  // 16 Byte Alignment
  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());

  // bias, azp, azp_adj are all 1-D: bias and azp_adj have n elements
  // (per output column), azp has m elements (per input row).
  if (bias) {
    TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous());
  }
  if (azp) {
    TORCH_CHECK(azp->numel() == a.size(0) && azp->is_contiguous());
  }
  TORCH_CHECK(azp_adj.numel() == b.size(1) && azp_adj.is_contiguous());

  // Dtypes: azp terms are int32; bias must match the output dtype.
  TORCH_CHECK(azp_adj.dtype() == torch::kInt32);
  TORCH_CHECK(!azp || azp->dtype() == torch::kInt32);
  TORCH_CHECK(!bias || bias->dtype() == c.dtype(),
              "currently bias dtype must match output dtype ", c.dtype());

  at::cuda::OptionalCUDAGuard const device_guard(device_of(a));

  int32_t const sm_version = get_sm_version_num();

#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
  if (sm_version >= 90) {
    // Hopper and newer
    cutlass_scaled_mm_azp_sm90(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
    return;
  }
#endif

#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
  if (sm_version == 89) {
    // Ada Lovelace
    cutlass_scaled_mm_azp_sm89(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
    return;
  }
  if (sm_version >= 80) {
    // Ampere
    cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
    return;
  }
  // Turing is the oldest supported generation when C2X kernels are built.
  TORCH_CHECK(sm_version >= 75);
  cutlass_scaled_mm_azp_sm75(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
  return;
#endif

  // Reached only when no matching kernel generation was compiled in.
  TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "No compiled cutlass_scaled_mm_azp for a compute capability less than "
      "CUDA device capability: ",
      sm_version);
}
|
|
||||||
@@ -439,90 +439,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
" -> ()");
|
" -> ()");
|
||||||
// conditionally compiled so impl registration is in source file
|
// conditionally compiled so impl registration is in source file
|
||||||
|
|
||||||
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
|
|
||||||
// quantization, as well as bias
|
|
||||||
ops.def(
|
|
||||||
"cutlass_scaled_mm(Tensor! out, Tensor a,"
|
|
||||||
" Tensor b, Tensor a_scales,"
|
|
||||||
" Tensor b_scales, Tensor? bias) -> ()");
|
|
||||||
ops.impl("cutlass_scaled_mm", torch::kCUDA, &cutlass_scaled_mm);
|
|
||||||
|
|
||||||
// CUTLASS w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
|
|
||||||
// quantization.
|
|
||||||
ops.def(
|
|
||||||
"cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
|
|
||||||
" Tensor b, Tensor a_scales,"
|
|
||||||
" Tensor b_scales, Tensor azp_adj,"
|
|
||||||
" Tensor? azp, Tensor? bias) -> ()");
|
|
||||||
ops.impl("cutlass_scaled_mm_azp", torch::kCUDA, &cutlass_scaled_mm_azp);
|
|
||||||
|
|
||||||
// Check if cutlass scaled_mm is supported for CUDA devices of the given
|
|
||||||
// capability
|
|
||||||
ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool");
|
|
||||||
ops.impl("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
|
|
||||||
|
|
||||||
// Check if cutlass grouped gemm is supported for CUDA devices of the given
|
|
||||||
// capability
|
|
||||||
ops.def("cutlass_group_gemm_supported(int cuda_device_capability) -> bool");
|
|
||||||
ops.impl("cutlass_group_gemm_supported", &cutlass_group_gemm_supported);
|
|
||||||
|
|
||||||
// CUTLASS w8a8 grouped GEMM
|
|
||||||
ops.def(
|
|
||||||
"cutlass_moe_mm(Tensor! out_tensors, Tensor a_tensors, Tensor b_tensors, "
|
|
||||||
" Tensor a_scales, Tensor b_scales, Tensor expert_offsets, "
|
|
||||||
" Tensor problem_sizes, Tensor a_strides, "
|
|
||||||
" Tensor b_strides, Tensor c_strides, bool per_act_token, "
|
|
||||||
" bool per_out_ch) -> ()");
|
|
||||||
ops.impl("cutlass_moe_mm", torch::kCUDA, &cutlass_moe_mm);
|
|
||||||
|
|
||||||
// A function that computes data required to run fused MoE with w8a8 grouped
|
|
||||||
// GEMM. It takes topk_ids as an input, and computes expert_offsets
|
|
||||||
// (token start indices of each expert). In addition to this, it computes
|
|
||||||
// problem sizes for each expert's multiplication used by the two mms called
|
|
||||||
// from fused MoE operation, and arrays with permutations required to shuffle
|
|
||||||
// and de-shuffle the input/output of the fused operation.
|
|
||||||
ops.def(
|
|
||||||
"get_cutlass_moe_mm_data(Tensor topk_ids, Tensor! expert_offsets, "
|
|
||||||
" Tensor! problem_sizes1, Tensor! problem_sizes2, "
|
|
||||||
" Tensor! input_permutation, "
|
|
||||||
" Tensor! output_permutation, int num_experts, "
|
|
||||||
" int n, int k, Tensor? blockscale_offsets, "
|
|
||||||
" bool is_gated) -> ()");
|
|
||||||
ops.impl("get_cutlass_moe_mm_data", torch::kCUDA, &get_cutlass_moe_mm_data);
|
|
||||||
|
|
||||||
// compute per-expert problem sizes from expert_first_token_offset
|
|
||||||
// produced by vLLM's moe_permute kernel
|
|
||||||
ops.def(
|
|
||||||
"get_cutlass_moe_mm_problem_sizes_from_expert_offsets("
|
|
||||||
" Tensor expert_first_token_offset, "
|
|
||||||
" Tensor! problem_sizes1, "
|
|
||||||
" Tensor! problem_sizes2, "
|
|
||||||
" int n, int k, bool swap_ab) -> ()");
|
|
||||||
ops.impl("get_cutlass_moe_mm_problem_sizes_from_expert_offsets", torch::kCUDA,
|
|
||||||
&get_cutlass_moe_mm_problem_sizes_from_expert_offsets);
|
|
||||||
|
|
||||||
// A function that computes data required to run fused MoE with w8a8 grouped
|
|
||||||
// GEMM in batched expert format. It takes expert_num_tokens
|
|
||||||
// as an input, and computes expert_offsets (token start indices of each
|
|
||||||
// expert). In addition to this, it computes problem sizes for each expert's
|
|
||||||
// multiplication used by the two mms called from fused MoE operation.
|
|
||||||
ops.def(
|
|
||||||
"get_cutlass_batched_moe_mm_data(Tensor! expert_offsets, "
|
|
||||||
" Tensor! problem_sizes1, "
|
|
||||||
" Tensor! problem_sizes2, "
|
|
||||||
" Tensor expert_num_tokens, "
|
|
||||||
" int num_local_experts, int padded_m, "
|
|
||||||
" int n, int k) -> ()");
|
|
||||||
ops.impl("get_cutlass_batched_moe_mm_data", torch::kCUDA,
|
|
||||||
&get_cutlass_batched_moe_mm_data);
|
|
||||||
|
|
||||||
// Check if cutlass scaled_mm supports block quantization (used by DeepSeekV3)
|
|
||||||
ops.def(
|
|
||||||
"cutlass_scaled_mm_supports_block_fp8(int cuda_device_capability) -> "
|
|
||||||
"bool");
|
|
||||||
ops.impl("cutlass_scaled_mm_supports_block_fp8",
|
|
||||||
&cutlass_scaled_mm_supports_block_fp8);
|
|
||||||
|
|
||||||
// SM100 CUTLASS MLA decode
|
// SM100 CUTLASS MLA decode
|
||||||
ops.def(
|
ops.def(
|
||||||
"sm100_cutlass_mla_decode(Tensor! out, Tensor! lse, Tensor q_nope,"
|
"sm100_cutlass_mla_decode(Tensor! out, Tensor! lse, Tensor q_nope,"
|
||||||
@@ -640,7 +556,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
"Tensor? cache_indices,"
|
"Tensor? cache_indices,"
|
||||||
"Tensor? has_initial_state,"
|
"Tensor? has_initial_state,"
|
||||||
"Tensor! ssm_states,"
|
"Tensor! ssm_states,"
|
||||||
"int pad_slot_id,"
|
"int null_block_id,"
|
||||||
"int block_size,"
|
"int block_size,"
|
||||||
"Tensor? block_idx_first_scheduled_token,"
|
"Tensor? block_idx_first_scheduled_token,"
|
||||||
"Tensor? block_idx_last_scheduled_token,"
|
"Tensor? block_idx_last_scheduled_token,"
|
||||||
|
|||||||
@@ -590,7 +590,10 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
|||||||
# Install FlashInfer JIT cache (requires CUDA-version-specific index URL)
|
# Install FlashInfer JIT cache (requires CUDA-version-specific index URL)
|
||||||
# https://docs.flashinfer.ai/installation.html
|
# https://docs.flashinfer.ai/installation.html
|
||||||
# From versions.json: .flashinfer.version
|
# From versions.json: .flashinfer.version
|
||||||
ARG FLASHINFER_VERSION=0.6.6
|
# 0.6.7: CUTLASS 4.4.2 bump, fixes TMA grouped GEMM on SM12x (flashinfer#2798)
|
||||||
|
# TODO: bump to 0.6.8 when released for NVFP4/MXFP4 group GEMMs on
|
||||||
|
# SM120/SM121 (RTX 50 / DGX Spark) via flashinfer#2738
|
||||||
|
ARG FLASHINFER_VERSION=0.6.7
|
||||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||||
uv pip install --system flashinfer-jit-cache==${FLASHINFER_VERSION} \
|
uv pip install --system flashinfer-jit-cache==${FLASHINFER_VERSION} \
|
||||||
--extra-index-url https://flashinfer.ai/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') \
|
--extra-index-url https://flashinfer.ai/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') \
|
||||||
|
|||||||
@@ -217,13 +217,16 @@ RUN pip install setuptools==75.6.0 packaging==23.2 ninja==1.11.1.3 build==1.2.2.
|
|||||||
|
|
||||||
|
|
||||||
# build flashinfer for torch nightly from source around 10 mins
|
# build flashinfer for torch nightly from source around 10 mins
|
||||||
# release version: v0.6.6
|
# release version: v0.6.7
|
||||||
|
# 0.6.7: CUTLASS 4.4.2 bump, fixes TMA grouped GEMM on SM12x (flashinfer#2798)
|
||||||
|
# TODO: bump to 0.6.8 when released for NVFP4/MXFP4 group GEMMs on
|
||||||
|
# SM120/SM121 (RTX 50 / DGX Spark) via flashinfer#2738
|
||||||
# todo(elainewy): cache flashinfer build result for faster build
|
# todo(elainewy): cache flashinfer build result for faster build
|
||||||
ENV CCACHE_DIR=/root/.cache/ccache
|
ENV CCACHE_DIR=/root/.cache/ccache
|
||||||
RUN --mount=type=cache,target=/root/.cache/ccache \
|
RUN --mount=type=cache,target=/root/.cache/ccache \
|
||||||
--mount=type=cache,target=/root/.cache/uv \
|
--mount=type=cache,target=/root/.cache/uv \
|
||||||
echo "git clone flashinfer..." \
|
echo "git clone flashinfer..." \
|
||||||
&& git clone --depth 1 --branch v0.6.6 --recursive https://github.com/flashinfer-ai/flashinfer.git \
|
&& git clone --depth 1 --branch v0.6.7 --recursive https://github.com/flashinfer-ai/flashinfer.git \
|
||||||
&& cd flashinfer \
|
&& cd flashinfer \
|
||||||
&& git submodule update --init --recursive \
|
&& git submodule update --init --recursive \
|
||||||
&& echo "finish git clone flashinfer..." \
|
&& echo "finish git clone flashinfer..." \
|
||||||
|
|||||||
@@ -111,12 +111,9 @@ CMD ["/bin/bash"]
|
|||||||
|
|
||||||
FROM vllm-base AS vllm-openai
|
FROM vllm-base AS vllm-openai
|
||||||
|
|
||||||
# install additional dependencies for openai api server
|
|
||||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
|
||||||
uv pip install accelerate hf_transfer pytest pytest_asyncio lm_eval[api] modelscope
|
|
||||||
|
|
||||||
# install development dependencies (for testing)
|
# install development dependencies (for testing)
|
||||||
RUN uv pip install -e tests/vllm_test_utils
|
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||||
|
uv pip install -e tests/vllm_test_utils
|
||||||
|
|
||||||
# install NIXL and UCX from source code
|
# install NIXL and UCX from source code
|
||||||
ARG UCX_VERSION=e5d98879705239d254ede40b4a52891850cb5349
|
ARG UCX_VERSION=e5d98879705239d254ede40b4a52891850cb5349
|
||||||
|
|||||||
@@ -68,7 +68,7 @@
|
|||||||
"default": "true"
|
"default": "true"
|
||||||
},
|
},
|
||||||
"FLASHINFER_VERSION": {
|
"FLASHINFER_VERSION": {
|
||||||
"default": "0.6.6"
|
"default": "0.6.7"
|
||||||
},
|
},
|
||||||
"GDRCOPY_CUDA_VERSION": {
|
"GDRCOPY_CUDA_VERSION": {
|
||||||
"default": "12.8"
|
"default": "12.8"
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ Now supports 6 types of connectors:
|
|||||||
|
|
||||||
- **ExampleConnector**: refer to [examples/offline_inference/disaggregated-prefill-v1/run.sh](../../examples/offline_inference/disaggregated-prefill-v1/run.sh) for the example usage of ExampleConnector disaggregated prefilling.
|
- **ExampleConnector**: refer to [examples/offline_inference/disaggregated-prefill-v1/run.sh](../../examples/offline_inference/disaggregated-prefill-v1/run.sh) for the example usage of ExampleConnector disaggregated prefilling.
|
||||||
- **LMCacheConnectorV1**: refer to [examples/others/lmcache/disagg_prefill_lmcache_v1/disagg_example_nixl.sh](../../examples/others/lmcache/disagg_prefill_lmcache_v1/disagg_example_nixl.sh) for the example usage of LMCacheConnectorV1 disaggregated prefilling which uses NIXL as the underlying KV transmission.
|
- **LMCacheConnectorV1**: refer to [examples/others/lmcache/disagg_prefill_lmcache_v1/disagg_example_nixl.sh](../../examples/others/lmcache/disagg_prefill_lmcache_v1/disagg_example_nixl.sh) for the example usage of LMCacheConnectorV1 disaggregated prefilling which uses NIXL as the underlying KV transmission.
|
||||||
- **NixlConnector**: refer to [tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh](../../tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh) for the example usage of NixlConnector disaggregated prefilling, which supports fully async send/recv. For detailed usage guide, see [NixlConnector Usage Guide](nixl_connector_usage.md).
|
- **NixlConnector**: refer to [tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh](../../tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh) for the example usage of NixlConnector disaggregated prefilling, which supports fully async send/recv. For detailed usage guide, see [NixlConnector Usage Guide](nixl_connector_usage.md). For feature compatibility details, see [NixlConnector Compatibility Matrix](nixl_connector_compatibility.md).
|
||||||
- **P2pNcclConnector**: refer to [examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_example_p2p_nccl_xpyd.sh](../../examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_example_p2p_nccl_xpyd.sh) for the example usage of P2pNcclConnector disaggregated prefilling.
|
- **P2pNcclConnector**: refer to [examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_example_p2p_nccl_xpyd.sh](../../examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_example_p2p_nccl_xpyd.sh) for the example usage of P2pNcclConnector disaggregated prefilling.
|
||||||
- **MooncakeConnector**: refer to [examples/online_serving/disaggregated_serving/mooncake_connector/run_mooncake_connector.sh](../../examples/online_serving/disaggregated_serving/mooncake_connector/run_mooncake_connector.sh) for the example usage of MooncakeConnector disaggregated prefilling. For detailed usage guide, see [MooncakeConnector Usage Guide](mooncake_connector_usage.md).
|
- **MooncakeConnector**: refer to [examples/online_serving/disaggregated_serving/mooncake_connector/run_mooncake_connector.sh](../../examples/online_serving/disaggregated_serving/mooncake_connector/run_mooncake_connector.sh) for the example usage of MooncakeConnector disaggregated prefilling. For detailed usage guide, see [MooncakeConnector Usage Guide](mooncake_connector_usage.md).
|
||||||
- **MultiConnector**: take advantage of the kv_connector_extra_config: dict[str, Any] already present in KVTransferConfig to stash all the connectors we want in an ordered list of kwargs, such as:
|
- **MultiConnector**: take advantage of the kv_connector_extra_config: dict[str, Any] already present in KVTransferConfig to stash all the connectors we want in an ordered list of kwargs, such as:
|
||||||
|
|||||||
104
docs/features/nixl_connector_compatibility.md
Normal file
104
docs/features/nixl_connector_compatibility.md
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
# NixlConnector Compatibility Matrix
|
||||||
|
|
||||||
|
This page documents the feature compatibility of **disaggregated prefilling with the NixlConnector**. For general usage instructions, see the [NixlConnector Usage Guide](nixl_connector_usage.md). For an overview of disaggregated prefilling, see [Disaggregated Prefilling](disagg_prefill.md).
|
||||||
|
|
||||||
|
!!! note
|
||||||
|
This page reflects the current state of the codebase and is subject to change as features evolve. Entries marked 🟠 or ❌ may link to tracking issues. See the [NIXL connector roadmap](https://github.com/vllm-project/vllm/issues/33702) for upcoming feature development.
|
||||||
|
|
||||||
|
**Legend:**
|
||||||
|
|
||||||
|
- ✅ = Fully supported
|
||||||
|
- 🟠 = Partial support (see footnotes)
|
||||||
|
- ❌ = Not supported
|
||||||
|
- ❔ = Unknown / not yet validated
|
||||||
|
- 🚧 = Work in progress
|
||||||
|
|
||||||
|
!!! info "Universally supported features"
|
||||||
|
The following features work with **all** model architectures when using NixlConnector PD disaggregated serving:
|
||||||
|
|
||||||
|
[Chunked Prefill](../configuration/optimization.md#chunked-prefill) |
|
||||||
|
[APC (Prefix Caching)](automatic_prefix_caching.md) |
|
||||||
|
[Data Parallel](../serving/data_parallel_deployment.md) |
|
||||||
|
CUDA graph |
|
||||||
|
Logprobs |
|
||||||
|
Prompt Logprobs |
|
||||||
|
[Prompt Embeds](prompt_embeds.md) |
|
||||||
|
Multiple NIXL backends (UCX, GDS, LIBFABRIC, etc.)
|
||||||
|
|
||||||
|
## Model Architecture x Capability
|
||||||
|
|
||||||
|
<style>
|
||||||
|
td:not(:first-child) {
|
||||||
|
text-align: center !important;
|
||||||
|
}
|
||||||
|
td {
|
||||||
|
padding: 0.5rem !important;
|
||||||
|
white-space: nowrap;
|
||||||
|
}
|
||||||
|
|
||||||
|
th {
|
||||||
|
padding: 0.5rem !important;
|
||||||
|
min-width: 0 !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
th:not(:first-child) {
|
||||||
|
writing-mode: vertical-lr;
|
||||||
|
transform: rotate(180deg)
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
|
||||||
|
| Model type | <abbr title="Basic Prefill/Decode disaggregation">Basic PD</abbr> | <abbr title="Speculative Decoding">Spec Decode</abbr> | <abbr title="Heterogeneous Tensor Parallelism (P TP != D TP)">Hetero TP</abbr> | <abbr title="Cross-layer blocks optimization">Cross-layer blocks</abbr> | <abbr title="Sliding Window Attention">SWA</abbr> | <abbr title="CPU host buffer offload (e.g. TPU)">Host buffer</abbr> | <abbr title="Different block sizes on P and D">Hetero block size</abbr> |
|
||||||
|
| - | - | - | - | - | - | - | - |
|
||||||
|
| Dense Transformers | ✅ | ✅<sup>1</sup> | ✅ | ✅<sup>2</sup> | ✅ | ✅ | 🟠<sup>3</sup> |
|
||||||
|
| MLA (e.g. DeepSeek-V2/V3) | ✅ | ✅<sup>1</sup> | 🟠<sup>4</sup> | ✅<sup>2</sup> | ✅ | ✅ | 🟠<sup>3</sup> |
|
||||||
|
| Sparse MLA (e.g. DeepSeek-V3.2) | ✅ | ✅<sup>1</sup> | 🟠<sup>4</sup> | ✅<sup>2</sup> | ✅ | ✅ | 🟠<sup>3</sup> |
|
||||||
|
| Hybrid SSM / Mamba | ✅ | ❔ | 🚧<sup>5</sup> | ❌ | ✅ | ✅ | ❌<sup>6</sup> |
|
||||||
|
| MoE | ✅ | ✅<sup>1</sup> | ✅ | ✅<sup>2</sup> | ✅ | ✅ | 🟠<sup>3</sup> |
|
||||||
|
| Multimodal | ❔ | ❔ | ❔ | ❔ | ❔ | ❔ | ❔ |
|
||||||
|
| Encoder-Decoder | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||||
|
|
||||||
|
<sup>1</sup> P and D instances must use the same speculation configuration.
|
||||||
|
|
||||||
|
<sup>2</sup> Requires `FLASH_ATTN` or `FLASHINFER` backend **and** `HND` KV cache layout. Enable via `--kv-transfer-config '{"kv_connector_extra_config": {"enable_cross_layers_blocks": "True"}}'`.
|
||||||
|
|
||||||
|
<sup>3</sup> Supported only when HMA is **not** required (i.e., non-hybrid models). Block IDs are remapped automatically. Only P block size < D block size is supported.
|
||||||
|
|
||||||
|
<sup>4</sup> MLA KV cache is replicated across TP workers, so heterogeneous TP works but there is no head-splitting. When P TP > D TP, only a single read is executed (redundant ranks are skipped). D TP > P TP also works.
|
||||||
|
|
||||||
|
<sup>5</sup> Hybrid SSM (Mamba) models require **homogeneous TP** (`P TP == D TP`). Heterogeneous TP is not yet supported for Mamba layers.
|
||||||
|
|
||||||
|
<sup>6</sup> HMA (required by hybrid models) does not support different remote block sizes.
|
||||||
|
|
||||||
|
## Configuration Notes
|
||||||
|
|
||||||
|
### What must match between P and D
|
||||||
|
|
||||||
|
By default, a **compatibility hash** is checked during handshake. P and D instances must agree on:
|
||||||
|
|
||||||
|
- vLLM version and NIXL connector version
|
||||||
|
- Model (architecture, dtype, number of KV heads, head size, number of hidden layers)
|
||||||
|
- Attention backend
|
||||||
|
- KV cache dtype (`cache_dtype`)
|
||||||
|
|
||||||
|
!!! warning
|
||||||
|
Disable the hash check with `--kv-transfer-config '{"kv_connector_extra_config": {"enforce_handshake_compat": false}}'` at your own risk.
|
||||||
|
|
||||||
|
### What can safely differ between P and D
|
||||||
|
|
||||||
|
- `tensor-parallel-size` (heterogeneous TP, subject to model restrictions above)
|
||||||
|
- `block-size` (heterogeneous block size, subject to restrictions above)
|
||||||
|
- Number of KV cache blocks (determined by available memory on each instance)
|
||||||
|
|
||||||
|
### KV cache layout
|
||||||
|
|
||||||
|
- NixlConnector defaults to **`HND`** layout for optimal transfer performance (non-MLA models).
|
||||||
|
- `NHD` layout is supported but does **not** allow heterogeneous TP head splitting.
|
||||||
|
- Experimental `HND` ↔ `NHD` permute: enable via `--kv-transfer-config '{"enable_permute_local_kv": true}'`. Not supported with HMA.
|
||||||
|
|
||||||
|
### Quantized KV cache
|
||||||
|
|
||||||
|
[Quantized KV cache](quantization/quantized_kvcache.md) (e.g., FP8) requires both P and D instances to use the **same** `cache_dtype`. Mismatched cache dtypes will fail the compatibility hash check during handshake.
|
||||||
|
|
||||||
|
- **Static quantization** (scales loaded from checkpoint): ✅ Supported. Scales are loaded independently by each instance from the model checkpoint.
|
||||||
|
- **Dynamic quantization** (scales computed at runtime): ❌ Not supported. Per-block scales are not transferred alongside KV cache data.
|
||||||
|
- **Packed-layout scales** (scales stored inline with weights): ✅ Supported. Scales are transferred together with the KV cache blocks.
|
||||||
@@ -2,6 +2,8 @@
|
|||||||
|
|
||||||
NixlConnector is a high-performance KV cache transfer connector for vLLM's disaggregated prefilling feature. It provides fully asynchronous send/receive operations using the NIXL library for efficient cross-process KV cache transfer.
|
NixlConnector is a high-performance KV cache transfer connector for vLLM's disaggregated prefilling feature. It provides fully asynchronous send/receive operations using the NIXL library for efficient cross-process KV cache transfer.
|
||||||
|
|
||||||
|
For feature compatibility details (supported model architectures, TP configurations, and feature interactions), see the [NixlConnector Compatibility Matrix](nixl_connector_compatibility.md).
|
||||||
|
|
||||||
## Prerequisites
|
## Prerequisites
|
||||||
|
|
||||||
### Installation
|
### Installation
|
||||||
|
|||||||
@@ -505,7 +505,7 @@ Here is a summary of a plugin file:
|
|||||||
|
|
||||||
# adjust request. e.g.: set skip special tokens
|
# adjust request. e.g.: set skip special tokens
|
||||||
# to False for tool call output.
|
# to False for tool call output.
|
||||||
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
|
def adjust_request(self, request: ChatCompletionRequest | ResponsesRequest) -> ChatCompletionRequest | ResponsesRequest:
|
||||||
return request
|
return request
|
||||||
|
|
||||||
# implement the tool call parse for stream call
|
# implement the tool call parse for stream call
|
||||||
|
|||||||
@@ -1,7 +1,14 @@
|
|||||||
<!-- markdownlint-disable MD041 MD051 -->
|
<!-- markdownlint-disable MD041 MD051 -->
|
||||||
--8<-- [start:installation]
|
--8<-- [start:installation]
|
||||||
|
|
||||||
vLLM supports AMD GPUs with ROCm 6.3 or above. Pre-built wheels are available for ROCm 7.0.
|
vLLM supports AMD GPUs with ROCm 6.3 or above. Pre-built wheels are available for ROCm 7.0 and ROCm 7.2.1.
|
||||||
|
|
||||||
|
#### Prebuilt Wheels
|
||||||
|
|
||||||
|
| ROCm Variant | Python Version | ROCm Version | glibc Requirement | Supported Versions |
|
||||||
|
| ------------ | -------------- | ------------ | ----------------- | ------------------ |
|
||||||
|
| `rocm700` | 3.12 | 7.0 | >= 2.35 | `0.14.0` to `0.18.0` |
|
||||||
|
| `rocm721` | 3.12 | 7.2.1 | >= 2.35 | Nightly releases after commit `171775f306a333a9cf105bfd533bf3e113d401d9` |
|
||||||
|
|
||||||
--8<-- [end:installation]
|
--8<-- [end:installation]
|
||||||
--8<-- [start:requirements]
|
--8<-- [start:requirements]
|
||||||
@@ -23,26 +30,112 @@ If you need a different ROCm version or want to use an existing PyTorch installa
|
|||||||
To install the latest version of vLLM for Python 3.12, ROCm 7.0 and `glibc >= 2.35`.
|
To install the latest version of vLLM for Python 3.12, ROCm 7.0 and `glibc >= 2.35`.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
uv pip install vllm --extra-index-url https://wheels.vllm.ai/rocm/
|
uv pip install vllm --extra-index-url https://wheels.vllm.ai/rocm/ --upgrade
|
||||||
```
|
```
|
||||||
|
|
||||||
!!! tip
|
!!! tip
|
||||||
You can find out about which ROCm version the latest vLLM supports by checking the index in extra-index-url [https://wheels.vllm.ai/rocm/](https://wheels.vllm.ai/rocm/) .
|
You can find out which ROCm version the latest vLLM supports by checking the `vllm` package in the extra index <https://wheels.vllm.ai/rocm/> at [https://wheels.vllm.ai/rocm/vllm](https://wheels.vllm.ai/rocm/vllm).
|
||||||
|
|
||||||
|
Alternatively, you can use the following commands to automatically extract the wheel variant and version:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# automatically extract the available rocm variant
|
||||||
|
export VLLM_ROCM_VARIANT=$(curl -s https://wheels.vllm.ai/rocm/vllm | grep -oP 'rocm\d+' | head -1)
|
||||||
|
|
||||||
|
# automatically extract the vLLM version
|
||||||
|
export VLLM_VERSION=$(curl -s https://wheels.vllm.ai/rocm/vllm | grep -oP 'vllm-\K[0-9.]+' | head -1)
|
||||||
|
|
||||||
|
# inspect if the ROCm version is compatible with your environment
|
||||||
|
echo $VLLM_ROCM_VARIANT
|
||||||
|
echo $VLLM_VERSION
|
||||||
|
```
|
||||||
|
|
||||||
To install a specific version and ROCm variant of vLLM wheel.
|
To install a specific version and ROCm variant of vLLM wheel.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
uv pip install vllm --extra-index-url https://wheels.vllm.ai/rocm/0.15.0/rocm700
|
# version without the `v`
|
||||||
|
uv pip install vllm==${VLLM_VERSION} --extra-index-url https://wheels.vllm.ai/rocm/${VLLM_VERSION}/${VLLM_ROCM_VARIANT}
|
||||||
|
|
||||||
|
# Example
|
||||||
|
uv pip install vllm==0.18.0 --extra-index-url https://wheels.vllm.ai/rocm/0.18.0/rocm700
|
||||||
```
|
```
|
||||||
|
|
||||||
!!! warning "Caveats for using `pip`"
|
!!! warning "Caveats for using `pip`"
|
||||||
|
|
||||||
We recommend leveraging `uv` to install vLLM wheel. Using `pip` to install from custom indices is cumbersome, because `pip` combines packages from `--extra-index-url` and the default index, choosing only the latest version, which makes it difficult to install wheel from custom index if exact versions of all packages are specified exactly. In contrast, `uv` gives the extra index [higher priority than the default index](https://docs.astral.sh/uv/pip/compatibility/#packages-that-exist-on-multiple-indexes).
|
We recommend leveraging `uv` to install the vLLM wheel. Using `pip` to install from custom indices is cumbersome because `pip` combines packages from `--extra-index-url` and the default index, choosing only the latest version. This makes it difficult to install a wheel from a custom index unless exact versions of all packages are specified. In contrast, `uv` gives the extra index [higher priority than the default index](https://docs.astral.sh/uv/pip/compatibility/#packages-that-exist-on-multiple-indexes).
|
||||||
|
|
||||||
If you insist on using `pip`, you have to specify the exact vLLM version and full URL of the wheel path `https://wheels.vllm.ai/rocm/<version>/<rocm-variant>` (which can be obtained from the web page).
|
If you insist on using `pip`, you need to specify the exact vLLM version in the package name and provide the custom index URL `https://wheels.vllm.ai/rocm/${VLLM_VERSION}/${VLLM_ROCM_VARIANT}` via `--extra-index-url`.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install vllm==0.15.0+rocm700 --extra-index-url https://wheels.vllm.ai/rocm/0.15.0/rocm700
|
pip install vllm==0.18.0+rocm700 --extra-index-url https://wheels.vllm.ai/rocm/0.18.0/rocm700
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Install the latest code
|
||||||
|
|
||||||
|
LLM inference is a fast-evolving field, and the latest code may contain bug fixes, performance improvements, and new features that are not released yet. To allow users to try the latest code without waiting for the next release, vLLM provides wheels for every commit since commit `171775f306a333a9cf105bfd533bf3e113d401d9` on <https://wheels.vllm.ai/rocm/nightly/>. The custom index to be used is `https://wheels.vllm.ai/rocm/nightly/${VLLM_ROCM_VARIANT}`
|
||||||
|
|
||||||
|
**NOTE:** The first ROCm variant that supports nightly wheels is ROCm 7.2.1.
|
||||||
|
|
||||||
|
To install from the latest nightly index, run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# automatically extract the available rocm variant
|
||||||
|
export VLLM_ROCM_VARIANT=$(curl -s https://wheels.vllm.ai/rocm/nightly | \
|
||||||
|
grep -oP 'rocm\d+' | head -1 | sed 's/%2B/+/g')
|
||||||
|
|
||||||
|
# inspect if the ROCm version is compatible with your environment
|
||||||
|
echo $VLLM_ROCM_VARIANT
|
||||||
|
|
||||||
|
uv pip install --pre vllm \
|
||||||
|
--extra-index-url https://wheels.vllm.ai/rocm/nightly/${VLLM_ROCM_VARIANT} \
|
||||||
|
--index-strategy unsafe-best-match
|
||||||
|
```
|
||||||
|
|
||||||
|
##### Install specific revisions
|
||||||
|
|
||||||
|
If you want to access the wheels for previous commits (e.g. to bisect a behavior change or a performance regression), you can specify the commit hash in the URL. For example:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export VLLM_COMMIT=5b8c30d62b754b575e043ce2fc0dcbf8a64f6306
|
||||||
|
|
||||||
|
export VLLM_ROCM_VARIANT=$(curl -s https://wheels.vllm.ai/rocm/${VLLM_COMMIT} | \
|
||||||
|
grep -oP 'rocm\d+' | head -1 | sed 's/%2B/+/g')
|
||||||
|
|
||||||
|
# Extract the version from the wheel URL
|
||||||
|
export VLLM_VERSION=$(curl -s https://wheels.vllm.ai/rocm/${VLLM_COMMIT}/${VLLM_ROCM_VARIANT}/vllm/ | \
|
||||||
|
grep -oP 'vllm-\K[^-]+' | head -1 | sed 's/%2B/+/g')
|
||||||
|
|
||||||
|
# inspect the version if it is compatible with the ROCm version of your environment
|
||||||
|
echo $VLLM_ROCM_VARIANT
|
||||||
|
echo $VLLM_VERSION
|
||||||
|
|
||||||
|
uv pip install vllm==${VLLM_VERSION} \
|
||||||
|
--extra-index-url https://wheels.vllm.ai/rocm/${VLLM_COMMIT}/${VLLM_ROCM_VARIANT} \
|
||||||
|
--index-strategy unsafe-best-match
|
||||||
|
```
|
||||||
|
|
||||||
|
!!! warning "`pip` caveat"
|
||||||
|
|
||||||
|
Using `pip` to install from nightly indices is _not supported_, because `pip` combines packages from `--extra-index-url` and the default index, choosing only the latest version, which makes it difficult to install a development version prior to the released version. In contrast, `uv` gives the extra index [higher priority than the default index](https://docs.astral.sh/uv/pip/compatibility/#packages-that-exist-on-multiple-indexes).
|
||||||
|
|
||||||
|
If you insist on using `pip`, you need to specify the exact vLLM version in the package name and provide the custom index URL (which can be obtained from the web page).
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export VLLM_COMMIT=5b8c30d62b754b575e043ce2fc0dcbf8a64f6306
|
||||||
|
|
||||||
|
export VLLM_ROCM_VARIANT=$(curl -s https://wheels.vllm.ai/rocm/${VLLM_COMMIT} | \
|
||||||
|
grep -oP 'rocm\d+' | head -1 | sed 's/%2B/+/g')
|
||||||
|
|
||||||
|
# Extract the version from the wheel URL
|
||||||
|
export VLLM_VERSION=$(curl -s https://wheels.vllm.ai/rocm/${VLLM_COMMIT}/${VLLM_ROCM_VARIANT}/vllm/ | \
|
||||||
|
grep -oP 'vllm-\K[^-]+' | head -1 | sed 's/%2B/+/g')
|
||||||
|
|
||||||
|
# inspect the version if it is compatible with the ROCm version of your environment
|
||||||
|
echo $VLLM_ROCM_VARIANT
|
||||||
|
echo $VLLM_VERSION
|
||||||
|
|
||||||
|
pip install vllm==${VLLM_VERSION} \
|
||||||
|
--extra-index-url https://wheels.vllm.ai/rocm/${VLLM_COMMIT}/${VLLM_ROCM_VARIANT}
|
||||||
```
|
```
|
||||||
|
|
||||||
--8<-- [end:pre-built-wheels]
|
--8<-- [end:pre-built-wheels]
|
||||||
@@ -193,6 +286,24 @@ docker run --rm \
|
|||||||
--model Qwen/Qwen3-0.6B
|
--model Qwen/Qwen3-0.6B
|
||||||
```
|
```
|
||||||
|
|
||||||
|
To use the Docker image as a base for development, you can launch it in an interactive session by overriding the entrypoint.
|
||||||
|
|
||||||
|
???+ console "Commands"
|
||||||
|
```bash
|
||||||
|
docker run --rm -it \
|
||||||
|
--group-add=video \
|
||||||
|
--cap-add=SYS_PTRACE \
|
||||||
|
--security-opt seccomp=unconfined \
|
||||||
|
--device /dev/kfd \
|
||||||
|
--device /dev/dri \
|
||||||
|
-v ~/.cache/huggingface:/root/.cache/huggingface \
|
||||||
|
--env "HF_TOKEN=$HF_TOKEN" \
|
||||||
|
--network=host \
|
||||||
|
--ipc=host \
|
||||||
|
--entrypoint /bin/bash \
|
||||||
|
vllm/vllm-openai-rocm:<tag>
|
||||||
|
```
|
||||||
|
|
||||||
#### Use AMD's Docker Images (Deprecated)
|
#### Use AMD's Docker Images (Deprecated)
|
||||||
|
|
||||||
!!! warning "Deprecated"
|
!!! warning "Deprecated"
|
||||||
|
|||||||
@@ -66,6 +66,10 @@ Restrict domains that vLLM can access for media URLs by setting
|
|||||||
`--allowed-media-domains` to prevent Server-Side Request Forgery (SSRF) attacks.
|
`--allowed-media-domains` to prevent Server-Side Request Forgery (SSRF) attacks.
|
||||||
(e.g. `--allowed-media-domains upload.wikimedia.org github.com www.bogotobogo.com`)
|
(e.g. `--allowed-media-domains upload.wikimedia.org github.com www.bogotobogo.com`)
|
||||||
|
|
||||||
|
This protection applies to both the online serving API (multimodal inputs) and
|
||||||
|
the **batch runner** (`vllm run-batch`), where `file_url` values in batch
|
||||||
|
transcription/translation requests are validated against the same allowlist.
|
||||||
|
|
||||||
Without domain restrictions, a malicious user could supply URLs that:
|
Without domain restrictions, a malicious user could supply URLs that:
|
||||||
|
|
||||||
- **Target internal services**: Access internal network endpoints, cloud metadata
|
- **Target internal services**: Access internal network endpoints, cloud metadata
|
||||||
|
|||||||
@@ -4,9 +4,10 @@
|
|||||||
experimental support for tensor-parallel inference with torchrun,
|
experimental support for tensor-parallel inference with torchrun,
|
||||||
see https://github.com/vllm-project/vllm/issues/11400 for
|
see https://github.com/vllm-project/vllm/issues/11400 for
|
||||||
the motivation and use case for this example.
|
the motivation and use case for this example.
|
||||||
run the script with `torchrun --nproc-per-node=2 torchrun_example.py`,
|
run the script with `torchrun --nproc-per-node=4 torchrun_example.py`,
|
||||||
the argument 2 should match the `tensor_parallel_size` below.
|
the argument `4` should match the product of `tensor_parallel_size` and
|
||||||
see `tests/distributed/test_torchrun_example.py` for the unit test.
|
`pipeline_parallel_size` below. see `tests/distributed/test_torchrun_example.py`
|
||||||
|
for the unit test.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|||||||
@@ -26,8 +26,13 @@ MODEL="Qwen/Qwen2.5-VL-3B-Instruct" bash disagg_1e1p1d_example.sh
|
|||||||
|
|
||||||
# Use specific storage path
|
# Use specific storage path
|
||||||
EC_SHARED_STORAGE_PATH="/tmp/my_ec_cache" bash disagg_1e1p1d_example.sh
|
EC_SHARED_STORAGE_PATH="/tmp/my_ec_cache" bash disagg_1e1p1d_example.sh
|
||||||
|
|
||||||
|
# Run on XPU; scripts switch from CUDA_VISIBLE_DEVICES to ZE_AFFINITY_MASK
|
||||||
|
DEVICE_PLATFORM=xpu GPU_E=0 GPU_PD=1 bash disagg_1e1pd_example.sh
|
||||||
```
|
```
|
||||||
|
|
||||||
|
`DEVICE_PLATFORM` defaults to `cuda`. Set `DEVICE_PLATFORM=xpu` when running these examples on Intel GPUs so the scripts use `ZE_AFFINITY_MASK` instead of `CUDA_VISIBLE_DEVICES` for device selection.
|
||||||
|
|
||||||
## Encoder Instances
|
## Encoder Instances
|
||||||
|
|
||||||
Encoder engines should be launched with the following flags:
|
Encoder engines should be launched with the following flags:
|
||||||
|
|||||||
@@ -19,11 +19,29 @@ GPU_E="${GPU_E:-2}"
|
|||||||
GPU_P="${GPU_P:-2}"
|
GPU_P="${GPU_P:-2}"
|
||||||
GPU_D="${GPU_D:-3}"
|
GPU_D="${GPU_D:-3}"
|
||||||
|
|
||||||
|
# Device platform and affinity env name.
|
||||||
|
# DEVICE_PLATFORM supports: cuda, xpu
|
||||||
|
DEVICE_PLATFORM="${DEVICE_PLATFORM:-cuda}"
|
||||||
|
if [[ -z "${DEVICE_AFFINITY_ENV:-}" ]]; then
|
||||||
|
if [[ "${DEVICE_PLATFORM,,}" == "xpu" ]]; then
|
||||||
|
DEVICE_AFFINITY_ENV="ZE_AFFINITY_MASK"
|
||||||
|
else
|
||||||
|
DEVICE_AFFINITY_ENV="CUDA_VISIBLE_DEVICES"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
EC_SHARED_STORAGE_PATH="${EC_SHARED_STORAGE_PATH:-/tmp/ec_cache}"
|
EC_SHARED_STORAGE_PATH="${EC_SHARED_STORAGE_PATH:-/tmp/ec_cache}"
|
||||||
TIMEOUT_SECONDS="${TIMEOUT_SECONDS:-12000}" # wait_for_server timeout
|
TIMEOUT_SECONDS="${TIMEOUT_SECONDS:-12000}" # wait_for_server timeout
|
||||||
|
|
||||||
NUM_PROMPTS="${NUM_PROMPTS:-100}" # number of prompts to send in benchmark
|
NUM_PROMPTS="${NUM_PROMPTS:-100}" # number of prompts to send in benchmark
|
||||||
|
|
||||||
|
# Serve args
|
||||||
|
GPU_MEMORY_UTILIZATION_E="${GPU_MEMORY_UTILIZATION_E:-0.01}"
|
||||||
|
GPU_MEMORY_UTILIZATION_P="${GPU_MEMORY_UTILIZATION_P:-0.7}"
|
||||||
|
GPU_MEMORY_UTILIZATION_D="${GPU_MEMORY_UTILIZATION_D:-0.7}"
|
||||||
|
MAX_NUM_SEQS="${MAX_NUM_SEQS:-128}"
|
||||||
|
MAX_MODEL_LEN="${MAX_MODEL_LEN:-32768}"
|
||||||
|
|
||||||
export UCX_TLS=all
|
export UCX_TLS=all
|
||||||
export UCX_NET_DEVICES=all
|
export UCX_NET_DEVICES=all
|
||||||
|
|
||||||
@@ -92,14 +110,14 @@ mkdir -p "$EC_SHARED_STORAGE_PATH"
|
|||||||
###############################################################################
|
###############################################################################
|
||||||
# Encoder worker
|
# Encoder worker
|
||||||
###############################################################################
|
###############################################################################
|
||||||
CUDA_VISIBLE_DEVICES="$GPU_E" vllm serve "$MODEL" \
|
env "$DEVICE_AFFINITY_ENV=$GPU_E" vllm serve "$MODEL" \
|
||||||
--gpu-memory-utilization 0.01 \
|
--gpu-memory-utilization "$GPU_MEMORY_UTILIZATION_E" \
|
||||||
--port "$ENCODE_PORT" \
|
--port "$ENCODE_PORT" \
|
||||||
--enforce-eager \
|
--enforce-eager \
|
||||||
--enable-request-id-headers \
|
--enable-request-id-headers \
|
||||||
--no-enable-prefix-caching \
|
--no-enable-prefix-caching \
|
||||||
--max-num-batched-tokens 114688 \
|
--max-num-batched-tokens 114688 \
|
||||||
--max-num-seqs 128 \
|
--max-num-seqs "$MAX_NUM_SEQS" \
|
||||||
--allowed-local-media-path "${GIT_ROOT}"/tests/v1/ec_connector/integration \
|
--allowed-local-media-path "${GIT_ROOT}"/tests/v1/ec_connector/integration \
|
||||||
--ec-transfer-config '{
|
--ec-transfer-config '{
|
||||||
"ec_connector": "ECExampleConnector",
|
"ec_connector": "ECExampleConnector",
|
||||||
@@ -115,15 +133,16 @@ PIDS+=($!)
|
|||||||
###############################################################################
|
###############################################################################
|
||||||
# Prefill worker
|
# Prefill worker
|
||||||
###############################################################################
|
###############################################################################
|
||||||
CUDA_VISIBLE_DEVICES="$GPU_P" \
|
env "$DEVICE_AFFINITY_ENV=$GPU_P" \
|
||||||
UCX_NET_DEVICES=all \
|
UCX_NET_DEVICES=all \
|
||||||
VLLM_NIXL_SIDE_CHANNEL_PORT=5559 \
|
VLLM_NIXL_SIDE_CHANNEL_PORT=5559 \
|
||||||
vllm serve "$MODEL" \
|
vllm serve "$MODEL" \
|
||||||
--gpu-memory-utilization 0.7 \
|
--gpu-memory-utilization "$GPU_MEMORY_UTILIZATION_P" \
|
||||||
--port "$PREFILL_PORT" \
|
--port "$PREFILL_PORT" \
|
||||||
--enforce-eager \
|
--enforce-eager \
|
||||||
--enable-request-id-headers \
|
--enable-request-id-headers \
|
||||||
--max-num-seqs 128 \
|
--max-num-seqs "$MAX_NUM_SEQS" \
|
||||||
|
--max-model-len "$MAX_MODEL_LEN" \
|
||||||
--allowed-local-media-path "${GIT_ROOT}"/tests/v1/ec_connector/integration \
|
--allowed-local-media-path "${GIT_ROOT}"/tests/v1/ec_connector/integration \
|
||||||
--ec-transfer-config '{
|
--ec-transfer-config '{
|
||||||
"ec_connector": "ECExampleConnector",
|
"ec_connector": "ECExampleConnector",
|
||||||
@@ -143,15 +162,16 @@ PIDS+=($!)
|
|||||||
###############################################################################
|
###############################################################################
|
||||||
# Decode worker
|
# Decode worker
|
||||||
###############################################################################
|
###############################################################################
|
||||||
CUDA_VISIBLE_DEVICES="$GPU_D" \
|
env "$DEVICE_AFFINITY_ENV=$GPU_D" \
|
||||||
UCX_NET_DEVICES=all \
|
UCX_NET_DEVICES=all \
|
||||||
VLLM_NIXL_SIDE_CHANNEL_PORT=6000 \
|
VLLM_NIXL_SIDE_CHANNEL_PORT=6000 \
|
||||||
vllm serve "$MODEL" \
|
vllm serve "$MODEL" \
|
||||||
--gpu-memory-utilization 0.7 \
|
--gpu-memory-utilization "$GPU_MEMORY_UTILIZATION_D" \
|
||||||
--port "$DECODE_PORT" \
|
--port "$DECODE_PORT" \
|
||||||
--enforce-eager \
|
--enforce-eager \
|
||||||
--enable-request-id-headers \
|
--enable-request-id-headers \
|
||||||
--max-num-seqs 128 \
|
--max-num-seqs "$MAX_NUM_SEQS" \
|
||||||
|
--max-model-len "$MAX_MODEL_LEN" \
|
||||||
--allowed-local-media-path "${GIT_ROOT}"/tests/v1/ec_connector/integration \
|
--allowed-local-media-path "${GIT_ROOT}"/tests/v1/ec_connector/integration \
|
||||||
--kv-transfer-config '{
|
--kv-transfer-config '{
|
||||||
"kv_connector": "NixlConnector",
|
"kv_connector": "NixlConnector",
|
||||||
|
|||||||
@@ -17,11 +17,28 @@ PROXY_PORT="${PROXY_PORT:-10001}"
|
|||||||
GPU_E="${GPU_E:-0}"
|
GPU_E="${GPU_E:-0}"
|
||||||
GPU_PD="${GPU_PD:-1}"
|
GPU_PD="${GPU_PD:-1}"
|
||||||
|
|
||||||
|
# Device platform and affinity env name.
|
||||||
|
# DEVICE_PLATFORM supports: cuda, xpu
|
||||||
|
DEVICE_PLATFORM="${DEVICE_PLATFORM:-cuda}"
|
||||||
|
if [[ -z "${DEVICE_AFFINITY_ENV:-}" ]]; then
|
||||||
|
if [[ "${DEVICE_PLATFORM,,}" == "xpu" ]]; then
|
||||||
|
DEVICE_AFFINITY_ENV="ZE_AFFINITY_MASK"
|
||||||
|
else
|
||||||
|
DEVICE_AFFINITY_ENV="CUDA_VISIBLE_DEVICES"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
EC_SHARED_STORAGE_PATH="${EC_SHARED_STORAGE_PATH:-/tmp/ec_cache}"
|
EC_SHARED_STORAGE_PATH="${EC_SHARED_STORAGE_PATH:-/tmp/ec_cache}"
|
||||||
TIMEOUT_SECONDS="${TIMEOUT_SECONDS:-12000}" # wait_for_server timeout
|
TIMEOUT_SECONDS="${TIMEOUT_SECONDS:-12000}" # wait_for_server timeout
|
||||||
|
|
||||||
NUM_PROMPTS="${NUM_PROMPTS:-100}" # number of prompts to send in benchmark
|
NUM_PROMPTS="${NUM_PROMPTS:-100}" # number of prompts to send in benchmark
|
||||||
|
|
||||||
|
# Serve args
|
||||||
|
GPU_MEMORY_UTILIZATION_E="${GPU_MEMORY_UTILIZATION_E:-0.01}"
|
||||||
|
GPU_MEMORY_UTILIZATION_PD="${GPU_MEMORY_UTILIZATION_PD:-0.7}"
|
||||||
|
MAX_NUM_SEQS="${MAX_NUM_SEQS:-128}"
|
||||||
|
MAX_MODEL_LEN="${MAX_MODEL_LEN:-32768}"
|
||||||
|
|
||||||
###############################################################################
|
###############################################################################
|
||||||
# Helpers
|
# Helpers
|
||||||
###############################################################################
|
###############################################################################
|
||||||
@@ -86,14 +103,14 @@ mkdir -p "$EC_SHARED_STORAGE_PATH"
|
|||||||
###############################################################################
|
###############################################################################
|
||||||
# Encoder worker
|
# Encoder worker
|
||||||
###############################################################################
|
###############################################################################
|
||||||
CUDA_VISIBLE_DEVICES="$GPU_E" vllm serve "$MODEL" \
|
env "$DEVICE_AFFINITY_ENV=$GPU_E" vllm serve "$MODEL" \
|
||||||
--gpu-memory-utilization 0.01 \
|
--gpu-memory-utilization "$GPU_MEMORY_UTILIZATION_E" \
|
||||||
--port "$ENCODE_PORT" \
|
--port "$ENCODE_PORT" \
|
||||||
--enforce-eager \
|
--enforce-eager \
|
||||||
--enable-request-id-headers \
|
--enable-request-id-headers \
|
||||||
--no-enable-prefix-caching \
|
--no-enable-prefix-caching \
|
||||||
--max-num-batched-tokens 114688 \
|
--max-num-batched-tokens 114688 \
|
||||||
--max-num-seqs 128 \
|
--max-num-seqs "$MAX_NUM_SEQS" \
|
||||||
--allowed-local-media-path "${GIT_ROOT}"/tests/v1/ec_connector/integration \
|
--allowed-local-media-path "${GIT_ROOT}"/tests/v1/ec_connector/integration \
|
||||||
--ec-transfer-config '{
|
--ec-transfer-config '{
|
||||||
"ec_connector": "ECExampleConnector",
|
"ec_connector": "ECExampleConnector",
|
||||||
@@ -109,12 +126,13 @@ PIDS+=($!)
|
|||||||
###############################################################################
|
###############################################################################
|
||||||
# Prefill+Decode worker
|
# Prefill+Decode worker
|
||||||
###############################################################################
|
###############################################################################
|
||||||
CUDA_VISIBLE_DEVICES="$GPU_PD" vllm serve "$MODEL" \
|
env "$DEVICE_AFFINITY_ENV=$GPU_PD" vllm serve "$MODEL" \
|
||||||
--gpu-memory-utilization 0.7 \
|
--gpu-memory-utilization "$GPU_MEMORY_UTILIZATION_PD" \
|
||||||
--port "$PREFILL_DECODE_PORT" \
|
--port "$PREFILL_DECODE_PORT" \
|
||||||
--enforce-eager \
|
--enforce-eager \
|
||||||
--enable-request-id-headers \
|
--enable-request-id-headers \
|
||||||
--max-num-seqs 128 \
|
--max-num-seqs "$MAX_NUM_SEQS" \
|
||||||
|
--max-model-len "$MAX_MODEL_LEN" \
|
||||||
--allowed-local-media-path "${GIT_ROOT}"/tests/v1/ec_connector/integration \
|
--allowed-local-media-path "${GIT_ROOT}"/tests/v1/ec_connector/integration \
|
||||||
--ec-transfer-config '{
|
--ec-transfer-config '{
|
||||||
"ec_connector": "ECExampleConnector",
|
"ec_connector": "ECExampleConnector",
|
||||||
|
|||||||
@@ -9,8 +9,8 @@ torchaudio==2.10.0
|
|||||||
# These must be updated alongside torch
|
# These must be updated alongside torch
|
||||||
torchvision==0.25.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
|
torchvision==0.25.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
|
||||||
# FlashInfer should be updated together with the Dockerfile
|
# FlashInfer should be updated together with the Dockerfile
|
||||||
flashinfer-python==0.6.6
|
flashinfer-python==0.6.7
|
||||||
flashinfer-cubin==0.6.6
|
flashinfer-cubin==0.6.7
|
||||||
# Cap nvidia-cudnn-frontend (transitive dep of flashinfer) due to
|
# Cap nvidia-cudnn-frontend (transitive dep of flashinfer) due to
|
||||||
# breaking changes in 1.19.0
|
# breaking changes in 1.19.0
|
||||||
nvidia-cudnn-frontend>=1.13.0,<1.19.0
|
nvidia-cudnn-frontend>=1.13.0,<1.19.0
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
# --- Test Infrastructure ---
|
# --- Test Infrastructure ---
|
||||||
tblib
|
tblib
|
||||||
|
pytest
|
||||||
|
pytest_asyncio
|
||||||
pytest-timeout
|
pytest-timeout
|
||||||
pytest-cov
|
pytest-cov
|
||||||
pytest-forked
|
pytest-forked
|
||||||
@@ -7,8 +9,13 @@ pytest-rerunfailures
|
|||||||
pytest-shard
|
pytest-shard
|
||||||
|
|
||||||
# --- Core Tools & Bindings ---
|
# --- Core Tools & Bindings ---
|
||||||
|
|
||||||
absl-py
|
absl-py
|
||||||
|
accelerate
|
||||||
arctic-inference
|
arctic-inference
|
||||||
|
hf_transfer
|
||||||
|
lm_eval[api]
|
||||||
|
modelscope
|
||||||
|
|
||||||
# --- Audio Processing ---
|
# --- Audio Processing ---
|
||||||
librosa
|
librosa
|
||||||
|
|||||||
@@ -1,42 +1,730 @@
|
|||||||
# XPU Test Dependencies
|
|
||||||
# NOTE: Base image already has common.txt + xpu.txt installed,
|
|
||||||
# and vllm-openai stage has pytest, pytest-asyncio, lm-eval[api].
|
|
||||||
# This file only adds incremental test-specific packages.
|
|
||||||
|
|
||||||
# Additional test infrastructure (pytest/pytest-asyncio already in base)
|
|
||||||
# This file was autogenerated by uv via the following command:
|
# This file was autogenerated by uv via the following command:
|
||||||
# uv pip compile /workspace/vllm/requirements/xpu-test.in -o /workspace/vllm/requirements/xpu-test.txt -c /workspace/vllm/requirements/xpu.txt --index-strategy unsafe-best-match --extra-index-url ${PIP_EXTRA_INDEX_URL} --python-version ${PYTHON_VERSION}
|
# uv pip compile requirements/xpu-test.in -o requirements/xpu-test.txt -c requirements/xpu.txt --python-version 3.12 --index-strategy unsafe-best-match
|
||||||
tblib==3.1.0
|
absl-py==2.4.0
|
||||||
pytest-timeout==2.3.1
|
# via
|
||||||
pytest-cov==6.3.0
|
# -r requirements/xpu-test.in
|
||||||
pytest-forked==1.6.0
|
# rouge-score
|
||||||
pytest-rerunfailures==14.0
|
accelerate==1.13.0
|
||||||
pytest-shard==0.1.2
|
# via -r requirements/xpu-test.in
|
||||||
|
aiohappyeyeballs==2.6.1
|
||||||
arctic-inference==0.1.1
|
# via aiohttp
|
||||||
|
aiohttp==3.13.4
|
||||||
# Required for audio processing tests
|
# via
|
||||||
librosa==0.10.2.post1
|
# -c requirements/common.txt
|
||||||
audioread==3.0.1
|
# fsspec
|
||||||
soxr==0.5.0.post1
|
# gpt-oss
|
||||||
pooch==1.8.2
|
# lm-eval
|
||||||
soundfile==0.13.1
|
aiosignal==1.4.0
|
||||||
|
# via aiohttp
|
||||||
# Required for Mistral's streaming tool parser
|
|
||||||
blobfile==3.0.0
|
|
||||||
rapidfuzz==3.12.1
|
|
||||||
|
|
||||||
# Required for Mistral's streaming tool parser and some evaluation scripts
|
|
||||||
gpt-oss==0.0.8
|
|
||||||
schemathesis==3.39.15
|
|
||||||
jiwer==4.0.0
|
|
||||||
bm25s==0.2.13
|
|
||||||
pystemmer==3.0.0
|
|
||||||
mteb[bm25s]>=2, <3
|
|
||||||
num2words==0.5.14
|
|
||||||
pqdm==0.2.0
|
|
||||||
|
|
||||||
# Required for some evaluation scripts
|
|
||||||
timm==1.0.17
|
|
||||||
albumentations==1.4.6
|
albumentations==1.4.6
|
||||||
mistral-common[image,audio]==1.9.1
|
# via -r requirements/xpu-test.in
|
||||||
|
annotated-doc==0.0.4
|
||||||
|
# via fastapi
|
||||||
|
annotated-types==0.7.0
|
||||||
|
# via pydantic
|
||||||
|
anyio==4.13.0
|
||||||
|
# via
|
||||||
|
# httpx
|
||||||
|
# starlette
|
||||||
|
arctic-inference==0.1.1
|
||||||
|
# via -r requirements/xpu-test.in
|
||||||
|
attrs==26.1.0
|
||||||
|
# via
|
||||||
|
# aiohttp
|
||||||
|
# jsonlines
|
||||||
|
# jsonschema
|
||||||
|
# referencing
|
||||||
|
audioread==3.0.1
|
||||||
|
# via
|
||||||
|
# -r requirements/xpu-test.in
|
||||||
|
# librosa
|
||||||
|
blobfile==3.0.0
|
||||||
|
# via -r requirements/xpu-test.in
|
||||||
|
bm25s==0.2.13
|
||||||
|
# via
|
||||||
|
# -r requirements/xpu-test.in
|
||||||
|
# mteb
|
||||||
|
bounded-pool-executor==0.0.3
|
||||||
|
# via pqdm
|
||||||
|
certifi==2026.2.25
|
||||||
|
# via
|
||||||
|
# httpcore
|
||||||
|
# httpx
|
||||||
|
# requests
|
||||||
|
cffi==2.0.0
|
||||||
|
# via soundfile
|
||||||
|
chardet==5.2.0
|
||||||
|
# via mbstrdecoder
|
||||||
|
charset-normalizer==3.4.6
|
||||||
|
# via requests
|
||||||
|
chz==0.4.0
|
||||||
|
# via gpt-oss
|
||||||
|
click==8.3.1
|
||||||
|
# via
|
||||||
|
# jiwer
|
||||||
|
# nltk
|
||||||
|
# schemathesis
|
||||||
|
# uvicorn
|
||||||
|
colorama==0.4.6
|
||||||
|
# via sacrebleu
|
||||||
|
coverage==7.13.5
|
||||||
|
# via pytest-cov
|
||||||
|
dataproperty==1.1.0
|
||||||
|
# via
|
||||||
|
# pytablewriter
|
||||||
|
# tabledata
|
||||||
|
datasets==4.8.4
|
||||||
|
# via
|
||||||
|
# evaluate
|
||||||
|
# lm-eval
|
||||||
|
# mteb
|
||||||
|
decorator==5.2.1
|
||||||
|
# via librosa
|
||||||
|
dill==0.4.1
|
||||||
|
# via
|
||||||
|
# datasets
|
||||||
|
# evaluate
|
||||||
|
# lm-eval
|
||||||
|
# multiprocess
|
||||||
|
docker==7.1.0
|
||||||
|
# via gpt-oss
|
||||||
|
docopt==0.6.2
|
||||||
|
# via num2words
|
||||||
|
dpcpp-cpp-rt==2025.3.1
|
||||||
|
# via
|
||||||
|
# onemkl-sycl-blas
|
||||||
|
# onemkl-sycl-dft
|
||||||
|
# onemkl-sycl-lapack
|
||||||
|
# onemkl-sycl-rng
|
||||||
|
# onemkl-sycl-sparse
|
||||||
|
# torch
|
||||||
|
evaluate==0.4.6
|
||||||
|
# via lm-eval
|
||||||
|
fastapi==0.135.2
|
||||||
|
# via
|
||||||
|
# -c requirements/common.txt
|
||||||
|
# gpt-oss
|
||||||
|
filelock==3.25.2
|
||||||
|
# via
|
||||||
|
# -c requirements/common.txt
|
||||||
|
# blobfile
|
||||||
|
# datasets
|
||||||
|
# huggingface-hub
|
||||||
|
# modelscope
|
||||||
|
# torch
|
||||||
|
# transformers
|
||||||
|
frozenlist==1.8.0
|
||||||
|
# via
|
||||||
|
# aiohttp
|
||||||
|
# aiosignal
|
||||||
|
fsspec==2026.2.0
|
||||||
|
# via
|
||||||
|
# datasets
|
||||||
|
# evaluate
|
||||||
|
# huggingface-hub
|
||||||
|
# torch
|
||||||
|
gpt-oss==0.0.8
|
||||||
|
# via -r requirements/xpu-test.in
|
||||||
|
graphql-core==3.2.8
|
||||||
|
# via hypothesis-graphql
|
||||||
|
h11==0.16.0
|
||||||
|
# via
|
||||||
|
# httpcore
|
||||||
|
# uvicorn
|
||||||
|
harfile==0.4.0
|
||||||
|
# via schemathesis
|
||||||
|
hf-transfer==0.1.9
|
||||||
|
# via -r requirements/xpu-test.in
|
||||||
|
hf-xet==1.4.2
|
||||||
|
# via huggingface-hub
|
||||||
|
html2text==2025.4.15
|
||||||
|
# via gpt-oss
|
||||||
|
httpcore==1.0.9
|
||||||
|
# via httpx
|
||||||
|
httpx==0.28.1
|
||||||
|
# via
|
||||||
|
# datasets
|
||||||
|
# schemathesis
|
||||||
|
huggingface-hub==0.36.2
|
||||||
|
# via
|
||||||
|
# accelerate
|
||||||
|
# datasets
|
||||||
|
# evaluate
|
||||||
|
# sentence-transformers
|
||||||
|
# timm
|
||||||
|
# tokenizers
|
||||||
|
# transformers
|
||||||
|
hypothesis==6.151.10
|
||||||
|
# via
|
||||||
|
# hypothesis-graphql
|
||||||
|
# hypothesis-jsonschema
|
||||||
|
# schemathesis
|
||||||
|
hypothesis-graphql==0.12.0
|
||||||
|
# via schemathesis
|
||||||
|
hypothesis-jsonschema==0.23.1
|
||||||
|
# via schemathesis
|
||||||
|
idna==3.11
|
||||||
|
# via
|
||||||
|
# anyio
|
||||||
|
# httpx
|
||||||
|
# requests
|
||||||
|
# yarl
|
||||||
|
imageio==2.37.3
|
||||||
|
# via scikit-image
|
||||||
|
impi-rt==2021.17.0
|
||||||
|
# via
|
||||||
|
# oneccl
|
||||||
|
# torch
|
||||||
|
iniconfig==2.3.0
|
||||||
|
# via pytest
|
||||||
|
intel-cmplr-lib-rt==2025.3.1
|
||||||
|
# via
|
||||||
|
# intel-sycl-rt
|
||||||
|
# torch
|
||||||
|
intel-cmplr-lib-ur==2025.3.1
|
||||||
|
# via
|
||||||
|
# intel-openmp
|
||||||
|
# intel-sycl-rt
|
||||||
|
# torch
|
||||||
|
intel-cmplr-lic-rt==2025.3.1
|
||||||
|
# via
|
||||||
|
# intel-opencl-rt
|
||||||
|
# intel-sycl-rt
|
||||||
|
# torch
|
||||||
|
intel-opencl-rt==2025.3.1
|
||||||
|
# via
|
||||||
|
# dpcpp-cpp-rt
|
||||||
|
# onemkl-sycl-blas
|
||||||
|
# onemkl-sycl-dft
|
||||||
|
# onemkl-sycl-lapack
|
||||||
|
# onemkl-sycl-rng
|
||||||
|
# onemkl-sycl-sparse
|
||||||
|
# torch
|
||||||
|
intel-openmp==2025.3.1
|
||||||
|
# via
|
||||||
|
# dpcpp-cpp-rt
|
||||||
|
# mkl
|
||||||
|
# torch
|
||||||
|
intel-pti==0.15.0
|
||||||
|
# via torch
|
||||||
|
intel-sycl-rt==2025.3.1
|
||||||
|
# via
|
||||||
|
# dpcpp-cpp-rt
|
||||||
|
# oneccl
|
||||||
|
# torch
|
||||||
|
jinja2==3.1.6
|
||||||
|
# via
|
||||||
|
# -c requirements/xpu.txt
|
||||||
|
# lm-eval
|
||||||
|
# torch
|
||||||
|
jiwer==4.0.0
|
||||||
|
# via -r requirements/xpu-test.in
|
||||||
|
joblib==1.5.3
|
||||||
|
# via
|
||||||
|
# librosa
|
||||||
|
# nltk
|
||||||
|
# scikit-learn
|
||||||
|
jsonlines==4.0.0
|
||||||
|
# via lm-eval
|
||||||
|
jsonschema==4.26.0
|
||||||
|
# via
|
||||||
|
# hypothesis-jsonschema
|
||||||
|
# mistral-common
|
||||||
|
# schemathesis
|
||||||
|
jsonschema-rs==0.45.0
|
||||||
|
# via schemathesis
|
||||||
|
jsonschema-specifications==2025.9.1
|
||||||
|
# via jsonschema
|
||||||
|
junit-xml==1.9
|
||||||
|
# via schemathesis
|
||||||
|
lazy-loader==0.5
|
||||||
|
# via
|
||||||
|
# librosa
|
||||||
|
# scikit-image
|
||||||
|
librosa==0.10.2.post1
|
||||||
|
# via -r requirements/xpu-test.in
|
||||||
|
llvmlite==0.44.0
|
||||||
|
# via numba
|
||||||
|
lm-eval==0.4.11
|
||||||
|
# via -r requirements/xpu-test.in
|
||||||
|
lxml==6.0.2
|
||||||
|
# via
|
||||||
|
# blobfile
|
||||||
|
# gpt-oss
|
||||||
|
# sacrebleu
|
||||||
|
markdown-it-py==4.0.0
|
||||||
|
# via rich
|
||||||
|
markupsafe==3.0.3
|
||||||
|
# via
|
||||||
|
# jinja2
|
||||||
|
# werkzeug
|
||||||
|
mbstrdecoder==1.1.4
|
||||||
|
# via
|
||||||
|
# dataproperty
|
||||||
|
# pytablewriter
|
||||||
|
# typepy
|
||||||
|
mdurl==0.1.2
|
||||||
|
# via markdown-it-py
|
||||||
|
mistral-common==1.10.0
|
||||||
|
# via
|
||||||
|
# -c requirements/common.txt
|
||||||
|
# -r requirements/xpu-test.in
|
||||||
|
mkl==2025.3.0
|
||||||
|
# via
|
||||||
|
# onemkl-sycl-blas
|
||||||
|
# onemkl-sycl-dft
|
||||||
|
# onemkl-sycl-lapack
|
||||||
|
# onemkl-sycl-rng
|
||||||
|
# onemkl-sycl-sparse
|
||||||
|
# torch
|
||||||
|
modelscope==1.35.3
|
||||||
|
# via -r requirements/xpu-test.in
|
||||||
|
more-itertools==10.8.0
|
||||||
|
# via lm-eval
|
||||||
|
mpmath==1.3.0
|
||||||
|
# via sympy
|
||||||
|
msgpack==1.1.2
|
||||||
|
# via librosa
|
||||||
|
mteb==2.12.7
|
||||||
|
# via -r requirements/xpu-test.in
|
||||||
|
multidict==6.7.1
|
||||||
|
# via
|
||||||
|
# aiohttp
|
||||||
|
# yarl
|
||||||
|
multiprocess==0.70.19
|
||||||
|
# via
|
||||||
|
# datasets
|
||||||
|
# evaluate
|
||||||
|
networkx==3.6.1
|
||||||
|
# via
|
||||||
|
# scikit-image
|
||||||
|
# torch
|
||||||
|
nltk==3.9.4
|
||||||
|
# via rouge-score
|
||||||
|
num2words==0.5.14
|
||||||
|
# via -r requirements/xpu-test.in
|
||||||
|
numba==0.61.2
|
||||||
|
# via
|
||||||
|
# -c requirements/xpu.txt
|
||||||
|
# librosa
|
||||||
|
numpy==2.2.6
|
||||||
|
# via
|
||||||
|
# accelerate
|
||||||
|
# albumentations
|
||||||
|
# bm25s
|
||||||
|
# datasets
|
||||||
|
# evaluate
|
||||||
|
# imageio
|
||||||
|
# librosa
|
||||||
|
# lm-eval
|
||||||
|
# mistral-common
|
||||||
|
# mteb
|
||||||
|
# numba
|
||||||
|
# opencv-python-headless
|
||||||
|
# pandas
|
||||||
|
# pytrec-eval-terrier
|
||||||
|
# rouge-score
|
||||||
|
# sacrebleu
|
||||||
|
# scikit-image
|
||||||
|
# scikit-learn
|
||||||
|
# scipy
|
||||||
|
# sentence-transformers
|
||||||
|
# soundfile
|
||||||
|
# soxr
|
||||||
|
# tifffile
|
||||||
|
# torchvision
|
||||||
|
# transformers
|
||||||
|
oneccl==2021.17.1
|
||||||
|
# via
|
||||||
|
# oneccl-devel
|
||||||
|
# torch
|
||||||
|
oneccl-devel==2021.17.1
|
||||||
|
# via torch
|
||||||
|
onemkl-license==2025.3.0
|
||||||
|
# via
|
||||||
|
# mkl
|
||||||
|
# torch
|
||||||
|
onemkl-sycl-blas==2025.3.0
|
||||||
|
# via
|
||||||
|
# onemkl-sycl-lapack
|
||||||
|
# onemkl-sycl-sparse
|
||||||
|
# torch
|
||||||
|
onemkl-sycl-dft==2025.3.0
|
||||||
|
# via torch
|
||||||
|
onemkl-sycl-lapack==2025.3.0
|
||||||
|
# via torch
|
||||||
|
onemkl-sycl-rng==2025.3.0
|
||||||
|
# via torch
|
||||||
|
onemkl-sycl-sparse==2025.3.0
|
||||||
|
# via torch
|
||||||
|
openai-harmony==0.0.8
|
||||||
|
# via
|
||||||
|
# -c requirements/common.txt
|
||||||
|
# gpt-oss
|
||||||
|
opencv-python-headless==4.13.0.92
|
||||||
|
# via
|
||||||
|
# -c requirements/common.txt
|
||||||
|
# albumentations
|
||||||
|
# mistral-common
|
||||||
|
packaging==26.0
|
||||||
|
# via
|
||||||
|
# -c requirements/xpu.txt
|
||||||
|
# accelerate
|
||||||
|
# datasets
|
||||||
|
# evaluate
|
||||||
|
# huggingface-hub
|
||||||
|
# lazy-loader
|
||||||
|
# modelscope
|
||||||
|
# pooch
|
||||||
|
# pytest
|
||||||
|
# pytest-rerunfailures
|
||||||
|
# scikit-image
|
||||||
|
# transformers
|
||||||
|
# typepy
|
||||||
|
pandas==3.0.1
|
||||||
|
# via
|
||||||
|
# datasets
|
||||||
|
# evaluate
|
||||||
|
pathvalidate==3.3.1
|
||||||
|
# via pytablewriter
|
||||||
|
pillow==12.1.1
|
||||||
|
# via
|
||||||
|
# imageio
|
||||||
|
# mistral-common
|
||||||
|
# scikit-image
|
||||||
|
# torchvision
|
||||||
|
platformdirs==4.9.4
|
||||||
|
# via pooch
|
||||||
|
pluggy==1.6.0
|
||||||
|
# via
|
||||||
|
# pytest
|
||||||
|
# pytest-cov
|
||||||
|
polars==1.39.3
|
||||||
|
# via mteb
|
||||||
|
polars-runtime-32==1.39.3
|
||||||
|
# via polars
|
||||||
|
pooch==1.8.2
|
||||||
|
# via
|
||||||
|
# -r requirements/xpu-test.in
|
||||||
|
# librosa
|
||||||
|
portalocker==3.2.0
|
||||||
|
# via sacrebleu
|
||||||
|
pqdm==0.2.0
|
||||||
|
# via -r requirements/xpu-test.in
|
||||||
|
propcache==0.4.1
|
||||||
|
# via
|
||||||
|
# aiohttp
|
||||||
|
# yarl
|
||||||
|
psutil==7.2.2
|
||||||
|
# via accelerate
|
||||||
|
py==1.11.0
|
||||||
|
# via pytest-forked
|
||||||
|
pyarrow==23.0.1
|
||||||
|
# via datasets
|
||||||
|
pycountry==26.2.16
|
||||||
|
# via pydantic-extra-types
|
||||||
|
pycparser==3.0
|
||||||
|
# via cffi
|
||||||
|
pycryptodomex==3.23.0
|
||||||
|
# via blobfile
|
||||||
|
pydantic==2.12.5
|
||||||
|
# via
|
||||||
|
# -c requirements/common.txt
|
||||||
|
# albumentations
|
||||||
|
# fastapi
|
||||||
|
# gpt-oss
|
||||||
|
# mistral-common
|
||||||
|
# mteb
|
||||||
|
# openai-harmony
|
||||||
|
# pydantic-extra-types
|
||||||
|
pydantic-core==2.41.5
|
||||||
|
# via pydantic
|
||||||
|
pydantic-extra-types==2.11.1
|
||||||
|
# via mistral-common
|
||||||
|
pyelftools==0.32
|
||||||
|
# via triton-xpu
|
||||||
|
pygments==2.20.0
|
||||||
|
# via
|
||||||
|
# pytest
|
||||||
|
# rich
|
||||||
|
pyrate-limiter==4.1.0
|
||||||
|
# via schemathesis
|
||||||
|
pystemmer==3.0.0
|
||||||
|
# via
|
||||||
|
# -r requirements/xpu-test.in
|
||||||
|
# mteb
|
||||||
|
pytablewriter==1.2.1
|
||||||
|
# via lm-eval
|
||||||
|
pytest==9.0.2
|
||||||
|
# via
|
||||||
|
# -r requirements/xpu-test.in
|
||||||
|
# pytest-asyncio
|
||||||
|
# pytest-cov
|
||||||
|
# pytest-forked
|
||||||
|
# pytest-rerunfailures
|
||||||
|
# pytest-shard
|
||||||
|
# pytest-timeout
|
||||||
|
# schemathesis
|
||||||
|
pytest-asyncio==1.3.0
|
||||||
|
# via -r requirements/xpu-test.in
|
||||||
|
pytest-cov==6.3.0
|
||||||
|
# via -r requirements/xpu-test.in
|
||||||
|
pytest-forked==1.6.0
|
||||||
|
# via -r requirements/xpu-test.in
|
||||||
|
pytest-rerunfailures==14.0
|
||||||
|
# via -r requirements/xpu-test.in
|
||||||
|
pytest-shard==0.1.2
|
||||||
|
# via -r requirements/xpu-test.in
|
||||||
|
pytest-timeout==2.3.1
|
||||||
|
# via -r requirements/xpu-test.in
|
||||||
|
python-dateutil==2.9.0.post0
|
||||||
|
# via
|
||||||
|
# pandas
|
||||||
|
# typepy
|
||||||
|
pytrec-eval-terrier==0.5.10
|
||||||
|
# via mteb
|
||||||
|
pytz==2026.1.post1
|
||||||
|
# via typepy
|
||||||
|
pyyaml==6.0.3
|
||||||
|
# via
|
||||||
|
# accelerate
|
||||||
|
# albumentations
|
||||||
|
# datasets
|
||||||
|
# huggingface-hub
|
||||||
|
# schemathesis
|
||||||
|
# timm
|
||||||
|
# transformers
|
||||||
|
rapidfuzz==3.12.1
|
||||||
|
# via
|
||||||
|
# -r requirements/xpu-test.in
|
||||||
|
# jiwer
|
||||||
|
referencing==0.37.0
|
||||||
|
# via
|
||||||
|
# jsonschema
|
||||||
|
# jsonschema-specifications
|
||||||
|
regex==2026.3.32
|
||||||
|
# via
|
||||||
|
# nltk
|
||||||
|
# sacrebleu
|
||||||
|
# tiktoken
|
||||||
|
# transformers
|
||||||
|
requests==2.33.1
|
||||||
|
# via
|
||||||
|
# -c requirements/common.txt
|
||||||
|
# datasets
|
||||||
|
# docker
|
||||||
|
# evaluate
|
||||||
|
# gpt-oss
|
||||||
|
# huggingface-hub
|
||||||
|
# lm-eval
|
||||||
|
# mistral-common
|
||||||
|
# modelscope
|
||||||
|
# mteb
|
||||||
|
# pooch
|
||||||
|
# schemathesis
|
||||||
|
# starlette-testclient
|
||||||
|
# tiktoken
|
||||||
|
# transformers
|
||||||
|
rich==14.3.3
|
||||||
|
# via
|
||||||
|
# mteb
|
||||||
|
# schemathesis
|
||||||
|
rouge-score==0.1.2
|
||||||
|
# via lm-eval
|
||||||
|
rpds-py==0.30.0
|
||||||
|
# via
|
||||||
|
# jsonschema
|
||||||
|
# referencing
|
||||||
|
sacrebleu==2.6.0
|
||||||
|
# via lm-eval
|
||||||
|
safetensors==0.7.0
|
||||||
|
# via
|
||||||
|
# accelerate
|
||||||
|
# timm
|
||||||
|
# transformers
|
||||||
|
schemathesis==4.14.2
|
||||||
|
# via -r requirements/xpu-test.in
|
||||||
|
scikit-image==0.26.0
|
||||||
|
# via albumentations
|
||||||
|
scikit-learn==1.8.0
|
||||||
|
# via
|
||||||
|
# albumentations
|
||||||
|
# librosa
|
||||||
|
# lm-eval
|
||||||
|
# mteb
|
||||||
|
# sentence-transformers
|
||||||
|
scipy==1.17.1
|
||||||
|
# via
|
||||||
|
# albumentations
|
||||||
|
# bm25s
|
||||||
|
# librosa
|
||||||
|
# mteb
|
||||||
|
# pytrec-eval-terrier
|
||||||
|
# scikit-image
|
||||||
|
# scikit-learn
|
||||||
|
# sentence-transformers
|
||||||
|
sentence-transformers==5.3.0
|
||||||
|
# via mteb
|
||||||
|
setuptools==80.10.2
|
||||||
|
# via
|
||||||
|
# -c requirements/common.txt
|
||||||
|
# -c requirements/xpu.txt
|
||||||
|
# modelscope
|
||||||
|
# pytablewriter
|
||||||
|
# torch
|
||||||
|
six==1.17.0
|
||||||
|
# via
|
||||||
|
# -c requirements/common.txt
|
||||||
|
# junit-xml
|
||||||
|
# python-dateutil
|
||||||
|
# rouge-score
|
||||||
|
sortedcontainers==2.4.0
|
||||||
|
# via hypothesis
|
||||||
|
soundfile==0.13.1
|
||||||
|
# via
|
||||||
|
# -r requirements/xpu-test.in
|
||||||
|
# librosa
|
||||||
|
# mistral-common
|
||||||
|
soxr==0.5.0.post1
|
||||||
|
# via
|
||||||
|
# -r requirements/xpu-test.in
|
||||||
|
# librosa
|
||||||
|
# mistral-common
|
||||||
|
sqlitedict==2.1.0
|
||||||
|
# via lm-eval
|
||||||
|
starlette==1.0.0
|
||||||
|
# via
|
||||||
|
# fastapi
|
||||||
|
# starlette-testclient
|
||||||
|
starlette-testclient==0.4.1
|
||||||
|
# via schemathesis
|
||||||
|
structlog==25.5.0
|
||||||
|
# via gpt-oss
|
||||||
|
sympy==1.14.0
|
||||||
|
# via torch
|
||||||
|
tabledata==1.3.4
|
||||||
|
# via pytablewriter
|
||||||
|
tabulate==0.10.0
|
||||||
|
# via sacrebleu
|
||||||
|
tbb==2022.3.0
|
||||||
|
# via
|
||||||
|
# intel-opencl-rt
|
||||||
|
# mkl
|
||||||
|
# torch
|
||||||
|
tblib==3.1.0
|
||||||
|
# via -r requirements/xpu-test.in
|
||||||
|
tcmlib==1.4.1
|
||||||
|
# via
|
||||||
|
# tbb
|
||||||
|
# torch
|
||||||
|
# umf
|
||||||
|
tcolorpy==0.1.7
|
||||||
|
# via pytablewriter
|
||||||
|
tenacity==9.1.4
|
||||||
|
# via
|
||||||
|
# gpt-oss
|
||||||
|
# lm-eval
|
||||||
|
# schemathesis
|
||||||
|
termcolor==3.3.0
|
||||||
|
# via gpt-oss
|
||||||
|
threadpoolctl==3.6.0
|
||||||
|
# via scikit-learn
|
||||||
|
tifffile==2026.3.3
|
||||||
|
# via scikit-image
|
||||||
|
tiktoken==0.12.0
|
||||||
|
# via
|
||||||
|
# -c requirements/common.txt
|
||||||
|
# gpt-oss
|
||||||
|
# lm-eval
|
||||||
|
# mistral-common
|
||||||
|
timm==1.0.17
|
||||||
|
# via -r requirements/xpu-test.in
|
||||||
|
tokenizers==0.22.2
|
||||||
|
# via
|
||||||
|
# -c requirements/common.txt
|
||||||
|
# transformers
|
||||||
|
torch==2.10.0+xpu
|
||||||
|
# via
|
||||||
|
# -c requirements/xpu.txt
|
||||||
|
# accelerate
|
||||||
|
# mteb
|
||||||
|
# sentence-transformers
|
||||||
|
# timm
|
||||||
|
# torchvision
|
||||||
|
torchvision==0.25.0+xpu
|
||||||
|
# via timm
|
||||||
|
tqdm==4.67.3
|
||||||
|
# via
|
||||||
|
# datasets
|
||||||
|
# evaluate
|
||||||
|
# huggingface-hub
|
||||||
|
# lm-eval
|
||||||
|
# modelscope
|
||||||
|
# mteb
|
||||||
|
# nltk
|
||||||
|
# pqdm
|
||||||
|
# sentence-transformers
|
||||||
|
# transformers
|
||||||
|
transformers==4.57.6
|
||||||
|
# via
|
||||||
|
# -c requirements/common.txt
|
||||||
|
# sentence-transformers
|
||||||
|
triton-xpu==3.6.0
|
||||||
|
# via torch
|
||||||
|
typepy==1.3.4
|
||||||
|
# via
|
||||||
|
# dataproperty
|
||||||
|
# pytablewriter
|
||||||
|
# tabledata
|
||||||
|
typing-extensions==4.15.0
|
||||||
|
# via
|
||||||
|
# -c requirements/common.txt
|
||||||
|
# aiosignal
|
||||||
|
# albumentations
|
||||||
|
# anyio
|
||||||
|
# chz
|
||||||
|
# fastapi
|
||||||
|
# huggingface-hub
|
||||||
|
# librosa
|
||||||
|
# lm-eval
|
||||||
|
# mistral-common
|
||||||
|
# mteb
|
||||||
|
# pqdm
|
||||||
|
# pydantic
|
||||||
|
# pydantic-core
|
||||||
|
# pydantic-extra-types
|
||||||
|
# pytest-asyncio
|
||||||
|
# referencing
|
||||||
|
# schemathesis
|
||||||
|
# sentence-transformers
|
||||||
|
# starlette
|
||||||
|
# torch
|
||||||
|
# typing-inspection
|
||||||
|
typing-inspection==0.4.2
|
||||||
|
# via
|
||||||
|
# fastapi
|
||||||
|
# pydantic
|
||||||
|
umf==1.0.2
|
||||||
|
# via
|
||||||
|
# intel-cmplr-lib-ur
|
||||||
|
# torch
|
||||||
|
urllib3==2.6.3
|
||||||
|
# via
|
||||||
|
# blobfile
|
||||||
|
# docker
|
||||||
|
# modelscope
|
||||||
|
# requests
|
||||||
|
uvicorn==0.42.0
|
||||||
|
# via gpt-oss
|
||||||
|
werkzeug==3.1.7
|
||||||
|
# via schemathesis
|
||||||
|
word2number==1.1
|
||||||
|
# via lm-eval
|
||||||
|
xxhash==3.6.0
|
||||||
|
# via
|
||||||
|
# datasets
|
||||||
|
# evaluate
|
||||||
|
yarl==1.23.0
|
||||||
|
# via aiohttp
|
||||||
|
zstandard==0.25.0
|
||||||
|
# via lm-eval
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import pytest
|
|||||||
|
|
||||||
from vllm.config import CompilationMode
|
from vllm.config import CompilationMode
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils.torch_utils import cuda_device_count_stateless
|
|
||||||
|
|
||||||
from ...utils import compare_all_settings
|
from ...utils import compare_all_settings
|
||||||
|
|
||||||
@@ -109,10 +108,10 @@ def test_compile_correctness(
|
|||||||
tp_size = test_setting.tp_size
|
tp_size = test_setting.tp_size
|
||||||
attn_backend = test_setting.attn_backend
|
attn_backend = test_setting.attn_backend
|
||||||
method = test_setting.method
|
method = test_setting.method
|
||||||
if cuda_device_count_stateless() < pp_size * tp_size:
|
if current_platform.device_count() < pp_size * tp_size:
|
||||||
pytest.skip(
|
pytest.skip(
|
||||||
f"Need at least {pp_size}*{tp_size} CUDA gpus but got "
|
f"Need at least {pp_size}*{tp_size} CUDA gpus but got "
|
||||||
f"{cuda_device_count_stateless()}"
|
f"{current_platform.device_count()}"
|
||||||
)
|
)
|
||||||
|
|
||||||
final_args = [
|
final_args = [
|
||||||
|
|||||||
@@ -412,7 +412,7 @@ def test_cudagraph_sizes_post_init(
|
|||||||
|
|
||||||
with (
|
with (
|
||||||
ctx,
|
ctx,
|
||||||
patch("vllm.config.parallel.cuda_device_count_stateless", return_value=tp_size),
|
patch.object(current_platform, "device_count", return_value=tp_size),
|
||||||
):
|
):
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
if cudagraph_capture_sizes is not None:
|
if cudagraph_capture_sizes is not None:
|
||||||
@@ -577,48 +577,6 @@ def test_compile_sizes_padding_validation():
|
|||||||
dispatcher.initialize_cudagraph_keys(CUDAGraphMode.NONE) # Should not raise
|
dispatcher.initialize_cudagraph_keys(CUDAGraphMode.NONE) # Should not raise
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"capture_sizes, max_size, num_blocks, expected_sizes, expected_max",
|
|
||||||
[
|
|
||||||
# Normal capping: sizes filtered to <= num_blocks
|
|
||||||
(
|
|
||||||
[1, 2, 4, 8, 16, 32, 64, 128, 256, 512],
|
|
||||||
512,
|
|
||||||
200,
|
|
||||||
[1, 2, 4, 8, 16, 32, 64, 128],
|
|
||||||
128,
|
|
||||||
),
|
|
||||||
# No capping needed: num_blocks >= max
|
|
||||||
([1, 2, 4, 8, 16], 16, 1000, [1, 2, 4, 8, 16], 16),
|
|
||||||
# Exact boundary: num_blocks == max (no capping)
|
|
||||||
([1, 2, 4, 8, 16, 32], 32, 32, [1, 2, 4, 8, 16, 32], 32),
|
|
||||||
# All sizes capped: num_blocks < smallest size
|
|
||||||
([8, 16, 32], 32, 4, [], 0),
|
|
||||||
# num_blocks <= 0: early return, no change
|
|
||||||
([1, 2, 4], 4, 0, [1, 2, 4], 4),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_adjust_cudagraph_sizes_for_mamba_cache(
|
|
||||||
capture_sizes, max_size, num_blocks, expected_sizes, expected_max
|
|
||||||
):
|
|
||||||
"""Test that cudagraph capture sizes are correctly capped to fit
|
|
||||||
available Mamba cache blocks.
|
|
||||||
|
|
||||||
See: https://github.com/vllm-project/vllm/issues/34094
|
|
||||||
"""
|
|
||||||
config = CompilationConfig(
|
|
||||||
cudagraph_capture_sizes=capture_sizes,
|
|
||||||
max_cudagraph_capture_size=max_size,
|
|
||||||
cudagraph_mode=CUDAGraphMode.NONE,
|
|
||||||
)
|
|
||||||
config.adjust_cudagraph_sizes_for_mamba_cache(num_blocks)
|
|
||||||
assert config.cudagraph_capture_sizes == expected_sizes
|
|
||||||
assert config.max_cudagraph_capture_size == expected_max
|
|
||||||
# Invariant: last element == max_cudagraph_capture_size
|
|
||||||
if expected_sizes:
|
|
||||||
assert config.cudagraph_capture_sizes[-1] == config.max_cudagraph_capture_size
|
|
||||||
|
|
||||||
|
|
||||||
def test_inductor_asserts_default_disabled(monkeypatch):
|
def test_inductor_asserts_default_disabled(monkeypatch):
|
||||||
"""Test that inductor runtime asserts are disabled by default
|
"""Test that inductor runtime asserts are disabled by default
|
||||||
(INFO logging level) on torch < 2.12."""
|
(INFO logging level) on torch < 2.12."""
|
||||||
|
|||||||
@@ -1,9 +1,11 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import atexit
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
|
|
||||||
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
|
|
||||||
@@ -16,9 +18,20 @@ from vllm.utils.system_utils import update_environment_variables
|
|||||||
mp.set_start_method("spawn", force=True)
|
mp.set_start_method("spawn", force=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _distributed_worker_wrapper(fn, env, world_size, args, rank, skip_queue):
|
||||||
|
try:
|
||||||
|
fn(env, world_size, *args)
|
||||||
|
except BaseException as exc:
|
||||||
|
if isinstance(exc, pytest.skip.Exception):
|
||||||
|
skip_queue.put((rank, str(exc)))
|
||||||
|
return
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
def distributed_run(fn, world_size, *args):
|
def distributed_run(fn, world_size, *args):
|
||||||
number_of_processes = world_size
|
number_of_processes = world_size
|
||||||
processes: list[mp.Process] = []
|
processes: list[mp.Process] = []
|
||||||
|
skip_queue: mp.SimpleQueue = mp.SimpleQueue()
|
||||||
for i in range(number_of_processes):
|
for i in range(number_of_processes):
|
||||||
env: dict[str, str] = {}
|
env: dict[str, str] = {}
|
||||||
env["RANK"] = str(i)
|
env["RANK"] = str(i)
|
||||||
@@ -27,13 +40,32 @@ def distributed_run(fn, world_size, *args):
|
|||||||
env["LOCAL_WORLD_SIZE"] = str(number_of_processes)
|
env["LOCAL_WORLD_SIZE"] = str(number_of_processes)
|
||||||
env["MASTER_ADDR"] = "localhost"
|
env["MASTER_ADDR"] = "localhost"
|
||||||
env["MASTER_PORT"] = "12345"
|
env["MASTER_PORT"] = "12345"
|
||||||
p = mp.Process(target=fn, args=(env, world_size, *args))
|
p = mp.Process(
|
||||||
|
target=_distributed_worker_wrapper,
|
||||||
|
args=(fn, env, world_size, args, i, skip_queue),
|
||||||
|
)
|
||||||
processes.append(p)
|
processes.append(p)
|
||||||
p.start()
|
p.start()
|
||||||
|
|
||||||
for p in processes:
|
for p in processes:
|
||||||
p.join()
|
p.join()
|
||||||
|
|
||||||
|
skipped: list[tuple[int, str]] = []
|
||||||
|
while not skip_queue.empty():
|
||||||
|
rank, reason = skip_queue.get()
|
||||||
|
skipped.append((rank, reason))
|
||||||
|
|
||||||
|
if len(skipped) == number_of_processes:
|
||||||
|
reason = skipped[0][1]
|
||||||
|
pytest.skip(reason)
|
||||||
|
if 0 < len(skipped) < number_of_processes:
|
||||||
|
skipped_ranks = sorted(rank for rank, _ in skipped)
|
||||||
|
raise AssertionError(
|
||||||
|
"Distributed test had partial skips; expected either all ranks "
|
||||||
|
f"to skip or none. Skipped ranks: {skipped_ranks}, "
|
||||||
|
f"total ranks: {number_of_processes}"
|
||||||
|
)
|
||||||
|
|
||||||
for p in processes:
|
for p in processes:
|
||||||
assert p.exitcode == 0
|
assert p.exitcode == 0
|
||||||
|
|
||||||
@@ -48,7 +80,12 @@ def set_env_vars_and_device(env: dict[str, str]) -> None:
|
|||||||
vllm_config = VllmConfig()
|
vllm_config = VllmConfig()
|
||||||
with set_current_vllm_config(vllm_config):
|
with set_current_vllm_config(vllm_config):
|
||||||
init_distributed_environment()
|
init_distributed_environment()
|
||||||
|
atexit.register(_destroy_process_group_if_initialized)
|
||||||
# Ensure each worker process has the same random seed
|
# Ensure each worker process has the same random seed
|
||||||
random.seed(42)
|
random.seed(42)
|
||||||
torch.manual_seed(42)
|
torch.manual_seed(42)
|
||||||
|
|
||||||
|
|
||||||
|
def _destroy_process_group_if_initialized() -> None:
|
||||||
|
if torch.distributed.is_available() and torch.distributed.is_initialized():
|
||||||
|
torch.distributed.destroy_process_group()
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import torch
|
|||||||
import torch.distributed
|
import torch.distributed
|
||||||
|
|
||||||
from vllm.config import VllmConfig, set_current_vllm_config
|
from vllm.config import VllmConfig, set_current_vllm_config
|
||||||
|
from vllm.distributed.eplb.eplb_communicator import create_eplb_communicator
|
||||||
from vllm.distributed.eplb.rebalance_execute import (
|
from vllm.distributed.eplb.rebalance_execute import (
|
||||||
move_from_buffer,
|
move_from_buffer,
|
||||||
rearrange_expert_weights_inplace,
|
rearrange_expert_weights_inplace,
|
||||||
@@ -130,9 +131,10 @@ def verify_expert_weights_after_shuffle(
|
|||||||
hidden_sizes: list[int],
|
hidden_sizes: list[int],
|
||||||
ep_rank: int,
|
ep_rank: int,
|
||||||
num_local_experts: int,
|
num_local_experts: int,
|
||||||
):
|
) -> bool:
|
||||||
"""Verify the weights after shuffling are correct."""
|
"""Verify the weights after shuffling are correct."""
|
||||||
num_layers = len(expert_weights)
|
num_layers = len(expert_weights)
|
||||||
|
ok = True
|
||||||
|
|
||||||
for layer in range(num_layers):
|
for layer in range(num_layers):
|
||||||
for weight_idx, hidden_size in enumerate(hidden_sizes):
|
for weight_idx, hidden_size in enumerate(hidden_sizes):
|
||||||
@@ -155,29 +157,38 @@ def verify_expert_weights_after_shuffle(
|
|||||||
dtype=actual_weights.dtype,
|
dtype=actual_weights.dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
torch.testing.assert_close(
|
if not torch.equal(actual_weights, expected_weights):
|
||||||
actual_weights,
|
ok = False
|
||||||
expected_weights,
|
actual_head = actual_weights[:8].detach().cpu().tolist()
|
||||||
msg=f"Layer {layer}, weight {weight_idx},"
|
expected_head = expected_weights[:8].detach().cpu().tolist()
|
||||||
f"local expert {local_expert}: "
|
print(
|
||||||
f"weights do not match. "
|
"verify_expert_weights_after_shuffle failed: "
|
||||||
f"Expected logical expert {expected_logical_expert}",
|
f"rank={ep_rank}, "
|
||||||
)
|
f"layer={layer}, weight_idx={weight_idx}, "
|
||||||
|
f"local_expert={local_expert}, "
|
||||||
|
f"expected_logical_expert={expected_logical_expert}, "
|
||||||
|
f"actual_head={actual_head}, expected_head={expected_head}",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ok
|
||||||
|
|
||||||
|
|
||||||
def verify_redundant_experts_have_same_weights(
|
def verify_redundant_experts_have_same_weights(
|
||||||
expert_weights: list[list[torch.Tensor]],
|
expert_weights: list[list[torch.Tensor]],
|
||||||
indices: torch.Tensor,
|
indices: torch.Tensor,
|
||||||
hidden_sizes: list[int],
|
hidden_sizes: list[int],
|
||||||
|
ep_rank: int,
|
||||||
world_size: int,
|
world_size: int,
|
||||||
num_local_experts: int,
|
num_local_experts: int,
|
||||||
):
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Verify that all replicas of the same logical expert have the same weights.
|
Verify that all replicas of the same logical expert have the same weights.
|
||||||
"""
|
"""
|
||||||
num_layers = len(expert_weights)
|
num_layers = len(expert_weights)
|
||||||
total_physical_experts = world_size * num_local_experts
|
total_physical_experts = world_size * num_local_experts
|
||||||
|
|
||||||
|
ok = True
|
||||||
for layer in range(num_layers):
|
for layer in range(num_layers):
|
||||||
# Collect weights for all physical experts for each weight matrix
|
# Collect weights for all physical experts for each weight matrix
|
||||||
all_weights: list[torch.Tensor] = []
|
all_weights: list[torch.Tensor] = []
|
||||||
@@ -227,14 +238,54 @@ def verify_redundant_experts_have_same_weights(
|
|||||||
# Verify that current physical expert's weights match the
|
# Verify that current physical expert's weights match the
|
||||||
# previously saved logical expert weights
|
# previously saved logical expert weights
|
||||||
for weight_idx in range(len(hidden_sizes)):
|
for weight_idx in range(len(hidden_sizes)):
|
||||||
torch.testing.assert_close(
|
if not torch.equal(
|
||||||
all_weights[weight_idx][physical_pos],
|
all_weights[weight_idx][physical_pos],
|
||||||
logical_expert_weights[logical_expert_id][weight_idx],
|
logical_expert_weights[logical_expert_id][weight_idx],
|
||||||
msg=f"Layer {layer}, weight {weight_idx},"
|
):
|
||||||
f"logical expert {logical_expert_id}: "
|
ok = False
|
||||||
f"Physical expert {physical_pos} has different weights"
|
actual_head = (
|
||||||
f"than expected",
|
all_weights[weight_idx][physical_pos][:8]
|
||||||
)
|
.detach()
|
||||||
|
.cpu()
|
||||||
|
.tolist()
|
||||||
|
)
|
||||||
|
reference_head = (
|
||||||
|
logical_expert_weights[logical_expert_id][weight_idx][:8]
|
||||||
|
.detach()
|
||||||
|
.cpu()
|
||||||
|
.tolist()
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
"verify_redundant_experts_have_same_weights failed: "
|
||||||
|
f"rank={ep_rank}, "
|
||||||
|
f"layer={layer}, weight_idx={weight_idx}, "
|
||||||
|
f"logical_expert={logical_expert_id}, "
|
||||||
|
f"physical_pos={physical_pos}, "
|
||||||
|
f"actual_head={actual_head}, "
|
||||||
|
f"reference_head={reference_head}",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ok
|
||||||
|
|
||||||
|
|
||||||
|
def assert_verification_synced(local_ok: bool, msg: str) -> None:
|
||||||
|
ok_tensor = torch.tensor([1 if local_ok else 0], device="cuda", dtype=torch.int32)
|
||||||
|
torch.distributed.all_reduce(ok_tensor, op=torch.distributed.ReduceOp.MIN)
|
||||||
|
assert bool(ok_tensor.item()), msg
|
||||||
|
|
||||||
|
|
||||||
|
def create_eplb_communicator_or_raise(*, group_coordinator, backend, expert_weights):
|
||||||
|
try:
|
||||||
|
return create_eplb_communicator(
|
||||||
|
group_coordinator=group_coordinator,
|
||||||
|
backend=backend,
|
||||||
|
expert_weights=expert_weights,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Failed to create EPLB communicator for backend={backend}: {exc}"
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
|
||||||
def _test_async_transfer_layer_without_mtp_worker(
|
def _test_async_transfer_layer_without_mtp_worker(
|
||||||
@@ -243,6 +294,7 @@ def _test_async_transfer_layer_without_mtp_worker(
|
|||||||
num_layers: int,
|
num_layers: int,
|
||||||
num_local_experts: int,
|
num_local_experts: int,
|
||||||
num_logical_experts: int,
|
num_logical_experts: int,
|
||||||
|
eplb_communicator: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
set_env_vars_and_device(env)
|
set_env_vars_and_device(env)
|
||||||
|
|
||||||
@@ -254,8 +306,8 @@ def _test_async_transfer_layer_without_mtp_worker(
|
|||||||
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
|
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
|
||||||
)
|
)
|
||||||
|
|
||||||
tp_group = get_tp_group()
|
ep_group_coordinator = get_tp_group()
|
||||||
ep_group = tp_group.device_group
|
ep_group = ep_group_coordinator.device_group
|
||||||
ep_rank = torch.distributed.get_rank()
|
ep_rank = torch.distributed.get_rank()
|
||||||
device = torch.device(f"cuda:{ep_rank}")
|
device = torch.device(f"cuda:{ep_rank}")
|
||||||
|
|
||||||
@@ -298,6 +350,13 @@ def _test_async_transfer_layer_without_mtp_worker(
|
|||||||
expert_buffer = [torch.empty_like(w) for w in expert_weights[0]]
|
expert_buffer = [torch.empty_like(w) for w in expert_weights[0]]
|
||||||
cuda_stream = torch.cuda.Stream(device=device)
|
cuda_stream = torch.cuda.Stream(device=device)
|
||||||
|
|
||||||
|
communicator = create_eplb_communicator_or_raise(
|
||||||
|
group_coordinator=ep_group_coordinator,
|
||||||
|
backend=eplb_communicator,
|
||||||
|
expert_weights=expert_weights[0],
|
||||||
|
)
|
||||||
|
communicator.set_stream(cuda_stream)
|
||||||
|
|
||||||
for layer_idx in range(num_layers):
|
for layer_idx in range(num_layers):
|
||||||
is_unchanged, is_received_locally, recv_metadata = asyncio.run(
|
is_unchanged, is_received_locally, recv_metadata = asyncio.run(
|
||||||
transfer_layer(
|
transfer_layer(
|
||||||
@@ -306,6 +365,7 @@ def _test_async_transfer_layer_without_mtp_worker(
|
|||||||
expert_weights=expert_weights[layer_idx],
|
expert_weights=expert_weights[layer_idx],
|
||||||
expert_weights_buffer=expert_buffer,
|
expert_weights_buffer=expert_buffer,
|
||||||
ep_group=ep_group,
|
ep_group=ep_group,
|
||||||
|
communicator=communicator,
|
||||||
cuda_stream=cuda_stream,
|
cuda_stream=cuda_stream,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -320,24 +380,38 @@ def _test_async_transfer_layer_without_mtp_worker(
|
|||||||
ep_rank=ep_rank,
|
ep_rank=ep_rank,
|
||||||
)
|
)
|
||||||
|
|
||||||
verify_expert_weights_after_shuffle(
|
local_ok = verify_expert_weights_after_shuffle(
|
||||||
expert_weights,
|
expert_weights,
|
||||||
new_indices,
|
new_indices,
|
||||||
hidden_sizes,
|
hidden_sizes,
|
||||||
ep_rank,
|
ep_rank,
|
||||||
num_local_experts,
|
num_local_experts,
|
||||||
)
|
)
|
||||||
|
local_ok = (
|
||||||
verify_redundant_experts_have_same_weights(
|
verify_redundant_experts_have_same_weights(
|
||||||
expert_weights,
|
expert_weights,
|
||||||
new_indices,
|
new_indices,
|
||||||
hidden_sizes,
|
hidden_sizes,
|
||||||
|
ep_rank,
|
||||||
world_size,
|
world_size,
|
||||||
num_local_experts,
|
num_local_experts,
|
||||||
)
|
)
|
||||||
|
and local_ok
|
||||||
|
)
|
||||||
|
assert_verification_synced(
|
||||||
|
local_ok,
|
||||||
|
"Async transfer verification failed on at least one rank. "
|
||||||
|
"See logs for details.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _test_rearrange_expert_weights_with_redundancy(
|
def _test_rearrange_expert_weights_with_redundancy(
|
||||||
env, world_size, num_layers, num_local_experts, num_logical_experts
|
env,
|
||||||
|
world_size,
|
||||||
|
num_layers,
|
||||||
|
num_local_experts,
|
||||||
|
num_logical_experts,
|
||||||
|
eplb_communicator: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
# Initialize model parallel (using tensor parallel as an entrypoint
|
# Initialize model parallel (using tensor parallel as an entrypoint
|
||||||
# to expert parallel)
|
# to expert parallel)
|
||||||
@@ -351,7 +425,8 @@ def _test_rearrange_expert_weights_with_redundancy(
|
|||||||
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
|
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
|
||||||
)
|
)
|
||||||
|
|
||||||
ep_group = get_tp_group().cpu_group
|
ep_group_coordinator = get_tp_group()
|
||||||
|
ep_group = ep_group_coordinator.cpu_group
|
||||||
ep_rank = torch.distributed.get_rank()
|
ep_rank = torch.distributed.get_rank()
|
||||||
device = torch.device(f"cuda:{ep_rank}")
|
device = torch.device(f"cuda:{ep_rank}")
|
||||||
|
|
||||||
@@ -387,6 +462,12 @@ def _test_rearrange_expert_weights_with_redundancy(
|
|||||||
num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices
|
num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices
|
||||||
)
|
)
|
||||||
|
|
||||||
|
communicator = create_eplb_communicator_or_raise(
|
||||||
|
group_coordinator=ep_group_coordinator,
|
||||||
|
backend=eplb_communicator,
|
||||||
|
expert_weights=expert_weights[0],
|
||||||
|
)
|
||||||
|
|
||||||
# Execute weight rearrangement
|
# Execute weight rearrangement
|
||||||
rearrange_expert_weights_inplace(
|
rearrange_expert_weights_inplace(
|
||||||
old_indices,
|
old_indices,
|
||||||
@@ -394,24 +475,33 @@ def _test_rearrange_expert_weights_with_redundancy(
|
|||||||
expert_weights,
|
expert_weights,
|
||||||
ep_group,
|
ep_group,
|
||||||
is_profile=False,
|
is_profile=False,
|
||||||
|
communicator=communicator,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify the rearrangement result
|
# Verify the rearrangement result
|
||||||
verify_expert_weights_after_shuffle(
|
local_ok = verify_expert_weights_after_shuffle(
|
||||||
expert_weights,
|
expert_weights,
|
||||||
new_indices,
|
new_indices,
|
||||||
hidden_sizes,
|
hidden_sizes,
|
||||||
ep_rank,
|
ep_rank,
|
||||||
num_local_experts,
|
num_local_experts,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
local_ok = (
|
||||||
verify_redundant_experts_have_same_weights(
|
verify_redundant_experts_have_same_weights(
|
||||||
expert_weights,
|
expert_weights,
|
||||||
new_indices,
|
new_indices,
|
||||||
hidden_sizes,
|
hidden_sizes,
|
||||||
|
ep_rank,
|
||||||
world_size,
|
world_size,
|
||||||
num_local_experts,
|
num_local_experts,
|
||||||
)
|
)
|
||||||
|
and local_ok
|
||||||
|
)
|
||||||
|
assert_verification_synced(
|
||||||
|
local_ok,
|
||||||
|
"Rearrange verification failed on at least one rank. See logs for details.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -437,8 +527,13 @@ def _test_rearrange_expert_weights_with_redundancy(
|
|||||||
(4, 8, 8, 16),
|
(4, 8, 8, 16),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@pytest.mark.parametrize("eplb_communicator", ["torch_nccl", "torch_gloo", "pynccl"])
|
||||||
def test_rearrange_expert_weights_with_redundancy(
|
def test_rearrange_expert_weights_with_redundancy(
|
||||||
world_size, num_layers, num_local_experts, num_logical_experts
|
world_size,
|
||||||
|
num_layers,
|
||||||
|
num_local_experts,
|
||||||
|
num_logical_experts,
|
||||||
|
eplb_communicator,
|
||||||
):
|
):
|
||||||
"""Test the functionality of rearranging expert weights with redundancy."""
|
"""Test the functionality of rearranging expert weights with redundancy."""
|
||||||
|
|
||||||
@@ -450,6 +545,7 @@ def test_rearrange_expert_weights_with_redundancy(
|
|||||||
num_layers,
|
num_layers,
|
||||||
num_local_experts,
|
num_local_experts,
|
||||||
num_logical_experts,
|
num_logical_experts,
|
||||||
|
eplb_communicator,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -464,7 +560,8 @@ def _test_rearrange_expert_weights_no_change(env, world_size) -> None:
|
|||||||
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
|
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
|
||||||
)
|
)
|
||||||
|
|
||||||
ep_group = get_tp_group().cpu_group
|
ep_group_coordinator = get_tp_group()
|
||||||
|
ep_group = ep_group_coordinator.cpu_group
|
||||||
ep_rank = torch.distributed.get_rank()
|
ep_rank = torch.distributed.get_rank()
|
||||||
device = torch.device(f"cuda:{ep_rank}")
|
device = torch.device(f"cuda:{ep_rank}")
|
||||||
|
|
||||||
@@ -494,24 +591,40 @@ def _test_rearrange_expert_weights_no_change(env, world_size) -> None:
|
|||||||
layer_copy.append(weight.clone())
|
layer_copy.append(weight.clone())
|
||||||
original_weights.append(layer_copy)
|
original_weights.append(layer_copy)
|
||||||
|
|
||||||
|
communicator = create_eplb_communicator_or_raise(
|
||||||
|
group_coordinator=ep_group_coordinator,
|
||||||
|
backend="torch_nccl",
|
||||||
|
expert_weights=expert_weights[0],
|
||||||
|
)
|
||||||
|
|
||||||
# Execute rearrangement (should be no change)
|
# Execute rearrangement (should be no change)
|
||||||
rearrange_expert_weights_inplace(
|
rearrange_expert_weights_inplace(
|
||||||
indices,
|
indices,
|
||||||
indices, # Same indices
|
indices, # Same indices
|
||||||
expert_weights,
|
expert_weights,
|
||||||
ep_group,
|
ep_group,
|
||||||
|
communicator,
|
||||||
is_profile=False,
|
is_profile=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify that the weights have not changed
|
# Verify that the weights have not changed
|
||||||
for layer in range(num_layers):
|
local_ok = True
|
||||||
for weight_idx in range(len(hidden_sizes)):
|
for layer in range(num_layers):
|
||||||
torch.testing.assert_close(
|
for weight_idx in range(len(hidden_sizes)):
|
||||||
expert_weights[layer][weight_idx],
|
if not torch.equal(
|
||||||
original_weights[layer][weight_idx],
|
expert_weights[layer][weight_idx],
|
||||||
msg=f"""Layer {layer}, weight {weight_idx}
|
original_weights[layer][weight_idx],
|
||||||
should remain unchanged""",
|
):
|
||||||
|
local_ok = False
|
||||||
|
print(
|
||||||
|
"test_rearrange_expert_weights_no_change failed: "
|
||||||
|
f"layer={layer}, weight_idx={weight_idx}",
|
||||||
|
flush=True,
|
||||||
)
|
)
|
||||||
|
assert_verification_synced(
|
||||||
|
local_ok,
|
||||||
|
"No-change EPLB verification failed on at least one rank.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -520,11 +633,13 @@ def _test_rearrange_expert_weights_no_change(env, world_size) -> None:
|
|||||||
(2, 2, 2, 3),
|
(2, 2, 2, 3),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@pytest.mark.parametrize("eplb_communicator", ["torch_nccl", "torch_gloo", "pynccl"])
|
||||||
def test_async_transfer_layer_without_mtp(
|
def test_async_transfer_layer_without_mtp(
|
||||||
world_size: int,
|
world_size: int,
|
||||||
num_layers: int,
|
num_layers: int,
|
||||||
num_local_experts: int,
|
num_local_experts: int,
|
||||||
num_logical_experts: int,
|
num_logical_experts: int,
|
||||||
|
eplb_communicator: str,
|
||||||
):
|
):
|
||||||
"""Exercise async EPLB transfer path without MTP/spec decode."""
|
"""Exercise async EPLB transfer path without MTP/spec decode."""
|
||||||
|
|
||||||
@@ -537,6 +652,7 @@ def test_async_transfer_layer_without_mtp(
|
|||||||
num_layers,
|
num_layers,
|
||||||
num_local_experts,
|
num_local_experts,
|
||||||
num_logical_experts,
|
num_logical_experts,
|
||||||
|
eplb_communicator,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -549,7 +665,10 @@ def test_rearrange_expert_weights_no_change(world_size):
|
|||||||
|
|
||||||
if torch.accelerator.device_count() < world_size:
|
if torch.accelerator.device_count() < world_size:
|
||||||
pytest.skip(f"Need at least {world_size} GPUs to run the test")
|
pytest.skip(f"Need at least {world_size} GPUs to run the test")
|
||||||
distributed_run(_test_rearrange_expert_weights_no_change, world_size)
|
distributed_run(
|
||||||
|
_test_rearrange_expert_weights_no_change,
|
||||||
|
world_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _test_rearrange_expert_weights_profile_mode(env, world_size) -> None:
|
def _test_rearrange_expert_weights_profile_mode(env, world_size) -> None:
|
||||||
@@ -563,7 +682,8 @@ def _test_rearrange_expert_weights_profile_mode(env, world_size) -> None:
|
|||||||
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
|
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
|
||||||
)
|
)
|
||||||
|
|
||||||
ep_group = get_tp_group().cpu_group
|
ep_group_coordinator = get_tp_group()
|
||||||
|
ep_group = ep_group_coordinator.cpu_group
|
||||||
ep_rank = torch.distributed.get_rank()
|
ep_rank = torch.distributed.get_rank()
|
||||||
device = torch.device(f"cuda:{ep_rank}")
|
device = torch.device(f"cuda:{ep_rank}")
|
||||||
|
|
||||||
@@ -600,23 +720,40 @@ def _test_rearrange_expert_weights_profile_mode(env, world_size) -> None:
|
|||||||
layer_copy.append(weight.clone())
|
layer_copy.append(weight.clone())
|
||||||
original_weights.append(layer_copy)
|
original_weights.append(layer_copy)
|
||||||
|
|
||||||
|
communicator = create_eplb_communicator_or_raise(
|
||||||
|
group_coordinator=ep_group_coordinator,
|
||||||
|
backend="torch_nccl",
|
||||||
|
expert_weights=expert_weights[0],
|
||||||
|
)
|
||||||
|
|
||||||
# Execute profile mode rearrangement
|
# Execute profile mode rearrangement
|
||||||
rearrange_expert_weights_inplace(
|
rearrange_expert_weights_inplace(
|
||||||
old_indices,
|
old_indices,
|
||||||
new_indices,
|
new_indices,
|
||||||
expert_weights,
|
expert_weights,
|
||||||
ep_group,
|
ep_group,
|
||||||
|
communicator,
|
||||||
is_profile=True, # Profile mode
|
is_profile=True, # Profile mode
|
||||||
)
|
)
|
||||||
|
|
||||||
# In profile mode, the weights should remain unchanged
|
# In profile mode, the weights should remain unchanged
|
||||||
for layer in range(num_layers):
|
local_ok = True
|
||||||
for weight_idx in range(len(hidden_sizes)):
|
for layer in range(num_layers):
|
||||||
torch.testing.assert_close(
|
for weight_idx in range(len(hidden_sizes)):
|
||||||
expert_weights[layer][weight_idx],
|
if not torch.equal(
|
||||||
original_weights[layer][weight_idx],
|
expert_weights[layer][weight_idx],
|
||||||
msg="In profile mode, the weights should remain unchanged",
|
original_weights[layer][weight_idx],
|
||||||
|
):
|
||||||
|
local_ok = False
|
||||||
|
print(
|
||||||
|
"test_rearrange_expert_weights_profile_mode failed: "
|
||||||
|
f"layer={layer}, weight_idx={weight_idx}",
|
||||||
|
flush=True,
|
||||||
)
|
)
|
||||||
|
assert_verification_synced(
|
||||||
|
local_ok,
|
||||||
|
"Profile-mode EPLB verification failed on at least one rank.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("world_size", [2, 4])
|
@pytest.mark.parametrize("world_size", [2, 4])
|
||||||
@@ -625,4 +762,7 @@ def test_rearrange_expert_weights_profile_mode(world_size):
|
|||||||
|
|
||||||
if torch.accelerator.device_count() < world_size:
|
if torch.accelerator.device_count() < world_size:
|
||||||
pytest.skip(f"Need at least {world_size} GPUs to run the test")
|
pytest.skip(f"Need at least {world_size} GPUs to run the test")
|
||||||
distributed_run(_test_rearrange_expert_weights_profile_mode, world_size)
|
distributed_run(
|
||||||
|
_test_rearrange_expert_weights_profile_mode,
|
||||||
|
world_size,
|
||||||
|
)
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ from vllm.distributed.utils import StatelessProcessGroup
|
|||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils.network_utils import get_open_port
|
from vllm.utils.network_utils import get_open_port
|
||||||
from vllm.utils.system_utils import update_environment_variables
|
from vllm.utils.system_utils import update_environment_variables
|
||||||
from vllm.utils.torch_utils import cuda_device_count_stateless
|
|
||||||
|
|
||||||
from ..utils import multi_gpu_test
|
from ..utils import multi_gpu_test
|
||||||
|
|
||||||
@@ -21,7 +20,7 @@ from ..utils import multi_gpu_test
|
|||||||
@ray.remote
|
@ray.remote
|
||||||
class _CUDADeviceCountStatelessTestActor:
|
class _CUDADeviceCountStatelessTestActor:
|
||||||
def get_count(self):
|
def get_count(self):
|
||||||
return cuda_device_count_stateless()
|
return current_platform.device_count()
|
||||||
|
|
||||||
def set_cuda_visible_devices(self, cuda_visible_devices: str):
|
def set_cuda_visible_devices(self, cuda_visible_devices: str):
|
||||||
update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
|
update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import json
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -1792,6 +1793,170 @@ async def test_tool_choice_validation_without_parser():
|
|||||||
assert "--tool-call-parser" in response_named.error.message
|
assert "--tool-call-parser" in response_named.error.message
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_streaming_n_gt1_independent_tool_parsers():
|
||||||
|
"""n>1 streaming must use independent parser instances
|
||||||
|
and token-id histories per choice.
|
||||||
|
"""
|
||||||
|
mock_engine = MagicMock(spec=AsyncLLM)
|
||||||
|
mock_engine.errored = False
|
||||||
|
mock_engine.model_config = MockModelConfig()
|
||||||
|
mock_engine.input_processor = MagicMock()
|
||||||
|
mock_engine.io_processor = MagicMock()
|
||||||
|
mock_engine.renderer = _build_renderer(mock_engine.model_config)
|
||||||
|
|
||||||
|
models = OpenAIServingModels(
|
||||||
|
engine_client=mock_engine,
|
||||||
|
base_model_paths=BASE_MODEL_PATHS,
|
||||||
|
)
|
||||||
|
openai_serving_render = _build_serving_render(mock_engine, models.registry)
|
||||||
|
|
||||||
|
serving_chat = OpenAIServingChat(
|
||||||
|
mock_engine,
|
||||||
|
models,
|
||||||
|
response_role="assistant",
|
||||||
|
openai_serving_render=openai_serving_render,
|
||||||
|
chat_template=CHAT_TEMPLATE,
|
||||||
|
chat_template_content_format="auto",
|
||||||
|
request_logger=None,
|
||||||
|
enable_auto_tools=True,
|
||||||
|
tool_parser="hermes",
|
||||||
|
)
|
||||||
|
|
||||||
|
tokenizer = get_tokenizer(MODEL_NAME)
|
||||||
|
|
||||||
|
tools = [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_weather",
|
||||||
|
"description": "Get weather",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"city": {"type": "string"}},
|
||||||
|
"required": ["city"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
num_choices = 2
|
||||||
|
|
||||||
|
request = ChatCompletionRequest(
|
||||||
|
model=MODEL_NAME,
|
||||||
|
messages=[{"role": "user", "content": "test"}],
|
||||||
|
n=num_choices,
|
||||||
|
stream=True,
|
||||||
|
tools=tools,
|
||||||
|
tool_choice="auto",
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_call_text = (
|
||||||
|
"<tool_call>\n"
|
||||||
|
'{"name": "get_weather", "arguments": {"city": "Tokyo"}}\n'
|
||||||
|
"</tool_call>"
|
||||||
|
)
|
||||||
|
all_token_ids = tokenizer.encode(tool_call_text, add_special_tokens=False)
|
||||||
|
|
||||||
|
# Compute proper delta text for each token so that concatenated deltas
|
||||||
|
# reproduce the original string exactly.
|
||||||
|
steps: list[tuple[str, int]] = []
|
||||||
|
prev_decoded = ""
|
||||||
|
for i, tid in enumerate(all_token_ids):
|
||||||
|
decoded_so_far = tokenizer.decode(all_token_ids[: i + 1])
|
||||||
|
delta = decoded_so_far[len(prev_decoded) :]
|
||||||
|
steps.append((delta, tid))
|
||||||
|
prev_decoded = decoded_so_far
|
||||||
|
|
||||||
|
async def result_generator():
|
||||||
|
for delta_text, token_id in steps:
|
||||||
|
yield RequestOutput(
|
||||||
|
request_id="test-req",
|
||||||
|
prompt="test",
|
||||||
|
prompt_token_ids=[1, 2, 3],
|
||||||
|
prompt_logprobs=None,
|
||||||
|
outputs=[
|
||||||
|
CompletionOutput(
|
||||||
|
index=choice_idx,
|
||||||
|
text=delta_text,
|
||||||
|
token_ids=[token_id],
|
||||||
|
cumulative_logprob=0.0,
|
||||||
|
logprobs=None,
|
||||||
|
)
|
||||||
|
for choice_idx in range(num_choices)
|
||||||
|
],
|
||||||
|
finished=False,
|
||||||
|
)
|
||||||
|
# Final output with finish_reason
|
||||||
|
yield RequestOutput(
|
||||||
|
request_id="test-req",
|
||||||
|
prompt="test",
|
||||||
|
prompt_token_ids=[1, 2, 3],
|
||||||
|
prompt_logprobs=None,
|
||||||
|
outputs=[
|
||||||
|
CompletionOutput(
|
||||||
|
index=choice_idx,
|
||||||
|
text="",
|
||||||
|
token_ids=[],
|
||||||
|
cumulative_logprob=0.0,
|
||||||
|
logprobs=None,
|
||||||
|
finish_reason="stop",
|
||||||
|
)
|
||||||
|
for choice_idx in range(num_choices)
|
||||||
|
],
|
||||||
|
finished=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Collect tool-call deltas per choice from the SSE stream.
|
||||||
|
tc_deltas_by_choice: dict[int, list[dict]] = {i: [] for i in range(num_choices)}
|
||||||
|
async for chunk_str in serving_chat.chat_completion_stream_generator(
|
||||||
|
request=request,
|
||||||
|
result_generator=result_generator(),
|
||||||
|
request_id="test-req",
|
||||||
|
model_name=MODEL_NAME,
|
||||||
|
conversation=[],
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
request_metadata=RequestResponseMetadata(
|
||||||
|
request_id="test-req",
|
||||||
|
model_name=MODEL_NAME,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
if not chunk_str.strip() or "data: [DONE]" in chunk_str:
|
||||||
|
continue
|
||||||
|
if chunk_str.startswith("data: "):
|
||||||
|
data = json.loads(chunk_str[6:].strip())
|
||||||
|
for choice in data.get("choices", []):
|
||||||
|
idx = choice["index"]
|
||||||
|
delta = choice.get("delta", {})
|
||||||
|
if delta.get("tool_calls"):
|
||||||
|
for tc in delta["tool_calls"]:
|
||||||
|
tc_deltas_by_choice[idx].append(tc)
|
||||||
|
|
||||||
|
# Both choices must independently produce the correct tool call.
|
||||||
|
for choice_idx in range(num_choices):
|
||||||
|
deltas = tc_deltas_by_choice[choice_idx]
|
||||||
|
assert len(deltas) > 0, (
|
||||||
|
f"Choice {choice_idx}: expected tool-call deltas but got none"
|
||||||
|
)
|
||||||
|
|
||||||
|
name = None
|
||||||
|
args_buf = ""
|
||||||
|
for tc in deltas:
|
||||||
|
fn = tc.get("function", {})
|
||||||
|
if fn.get("name"):
|
||||||
|
name = fn["name"]
|
||||||
|
if fn.get("arguments"):
|
||||||
|
args_buf += fn["arguments"]
|
||||||
|
|
||||||
|
assert name == "get_weather", (
|
||||||
|
f"Choice {choice_idx}: expected 'get_weather', got {name!r}"
|
||||||
|
)
|
||||||
|
parsed_args = json.loads(args_buf)
|
||||||
|
assert parsed_args == {"city": "Tokyo"}, (
|
||||||
|
f"Choice {choice_idx}: expected {{'city': 'Tokyo'}}, got {parsed_args}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestCreateRemainingArgsDelta:
|
class TestCreateRemainingArgsDelta:
|
||||||
"""Tests for _create_remaining_args_delta helper function.
|
"""Tests for _create_remaining_args_delta helper function.
|
||||||
|
|
||||||
|
|||||||
@@ -16,10 +16,41 @@ import soundfile as sf
|
|||||||
|
|
||||||
from tests.entrypoints.openai.conftest import add_attention_backend
|
from tests.entrypoints.openai.conftest import add_attention_backend
|
||||||
from tests.utils import RemoteOpenAIServer
|
from tests.utils import RemoteOpenAIServer
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
SERVER_ARGS = ["--enforce-eager"]
|
SERVER_ARGS = ["--enforce-eager"]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_rocm_attention_config(model_name):
|
||||||
|
"""Return appropriate ROCm attention config for the given model.
|
||||||
|
|
||||||
|
Whisper uses cross-attention (ENCODER_DECODER) which ROCM_AITER_FA does
|
||||||
|
not support. For Whisper we use ROCM_AITER_UNIFIED_ATTN (or TRITON_ATTN
|
||||||
|
as fallback); other models can use ROCM_AITER_FA.
|
||||||
|
"""
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
if not current_platform.is_rocm():
|
||||||
|
return None
|
||||||
|
|
||||||
|
if "whisper" in model_name.lower():
|
||||||
|
try:
|
||||||
|
from vllm.platforms.rocm import _ON_MI3XX
|
||||||
|
|
||||||
|
if _ON_MI3XX:
|
||||||
|
return {"backend": "ROCM_AITER_UNIFIED_ATTN"}
|
||||||
|
except ImportError:
|
||||||
|
logger.warning(
|
||||||
|
"Could not import _ON_MI3XX from rocm platform, "
|
||||||
|
"falling back to TRITON_ATTN for Whisper."
|
||||||
|
)
|
||||||
|
return {"backend": "TRITON_ATTN"}
|
||||||
|
|
||||||
|
return {"backend": "ROCM_AITER_FA"}
|
||||||
|
|
||||||
|
|
||||||
def _get_server_args(attention_config):
|
def _get_server_args(attention_config):
|
||||||
"""Get server args with attention backend if specified."""
|
"""Get server args with attention backend if specified."""
|
||||||
args = SERVER_ARGS.copy()
|
args = SERVER_ARGS.copy()
|
||||||
@@ -30,10 +61,11 @@ def _get_server_args(attention_config):
|
|||||||
@pytest.fixture(
|
@pytest.fixture(
|
||||||
scope="module", params=["openai/whisper-small", "google/gemma-3n-E2B-it"]
|
scope="module", params=["openai/whisper-small", "google/gemma-3n-E2B-it"]
|
||||||
)
|
)
|
||||||
def server(request, rocm_aiter_fa_attention):
|
def server(request):
|
||||||
# Parametrize over model name
|
# Parametrize over model name
|
||||||
|
attention_config = _get_rocm_attention_config(request.param)
|
||||||
with RemoteOpenAIServer(
|
with RemoteOpenAIServer(
|
||||||
request.param, _get_server_args(rocm_aiter_fa_attention)
|
request.param, _get_server_args(attention_config)
|
||||||
) as remote_server:
|
) as remote_server:
|
||||||
yield remote_server, request.param
|
yield remote_server, request.param
|
||||||
|
|
||||||
@@ -46,11 +78,12 @@ async def client_and_model(server):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_non_asr_model(foscolo, rocm_aiter_fa_attention):
|
async def test_non_asr_model(foscolo):
|
||||||
# text to text model
|
# text to text model
|
||||||
model_name = "JackFram/llama-68m"
|
model_name = "JackFram/llama-68m"
|
||||||
|
attention_config = _get_rocm_attention_config(model_name)
|
||||||
with RemoteOpenAIServer(
|
with RemoteOpenAIServer(
|
||||||
model_name, _get_server_args(rocm_aiter_fa_attention)
|
model_name, _get_server_args(attention_config)
|
||||||
) as remote_server:
|
) as remote_server:
|
||||||
client = remote_server.get_async_client()
|
client = remote_server.get_async_client()
|
||||||
|
|
||||||
@@ -61,7 +94,7 @@ async def test_non_asr_model(foscolo, rocm_aiter_fa_attention):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_basic_audio_with_lora(mary_had_lamb, rocm_aiter_fa_attention):
|
async def test_basic_audio_with_lora(mary_had_lamb):
|
||||||
"""Ensure STT (translate) requests can pass LoRA through to generate."""
|
"""Ensure STT (translate) requests can pass LoRA through to generate."""
|
||||||
# ROCm SPECIFIC CONFIGURATION:
|
# ROCm SPECIFIC CONFIGURATION:
|
||||||
# To ensure the test passes on ROCm, we modify the max model length to 512.
|
# To ensure the test passes on ROCm, we modify the max model length to 512.
|
||||||
@@ -85,7 +118,7 @@ async def test_basic_audio_with_lora(mary_had_lamb, rocm_aiter_fa_attention):
|
|||||||
"1",
|
"1",
|
||||||
]
|
]
|
||||||
|
|
||||||
add_attention_backend(server_args, rocm_aiter_fa_attention)
|
add_attention_backend(server_args, _get_rocm_attention_config(model_name))
|
||||||
|
|
||||||
# Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
|
# Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
|
||||||
with RemoteOpenAIServer(model_name, server_args) as remote_server:
|
with RemoteOpenAIServer(model_name, server_args) as remote_server:
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user