Compare commits
18 Commits
v0.18.2rc0
...
cmm
| Author | SHA1 | Date | |
|---|---|---|---|
| 013b73e9b2 | |||
| c77342da87 | |||
| 7f35bc4158 | |||
| 487dd34e04 | |||
| a15f86ecfa | |||
|
|
2a69949bda | ||
|
|
8adcf8c40a | ||
|
|
cfad6a509c | ||
|
|
c284a6671c | ||
|
|
3a30a1a6a8 | ||
|
|
29982d48b3 | ||
|
|
1dbbafd3f3 | ||
|
|
0ee3b7fc3d | ||
|
|
268bed9cf3 | ||
|
|
bcc0fdd0f3 | ||
|
|
69b8bd4b33 | ||
|
|
12449f9492 | ||
|
|
b92312dfd7 |
@@ -5,7 +5,6 @@ steps:
|
||||
depends_on: []
|
||||
device: amd_cpu
|
||||
no_plugin: true
|
||||
soft_fail: true
|
||||
commands:
|
||||
- >
|
||||
docker build
|
||||
@@ -21,3 +20,11 @@ steps:
|
||||
- docker push "rocm/vllm-ci:${BUILDKITE_COMMIT}"
|
||||
env:
|
||||
DOCKER_BUILDKIT: "1"
|
||||
retry:
|
||||
automatic:
|
||||
- exit_status: -1 # Agent was lost
|
||||
limit: 1
|
||||
- exit_status: -10 # Agent was lost
|
||||
limit: 1
|
||||
- exit_status: 1 # Machine occasionally fail
|
||||
limit: 1
|
||||
|
||||
@@ -13,14 +13,12 @@ steps:
|
||||
- tests/kernels/attention/test_cpu_attn.py
|
||||
- tests/kernels/moe/test_cpu_fused_moe.py
|
||||
- tests/kernels/test_onednn.py
|
||||
- tests/kernels/test_awq_int4_to_int8.py
|
||||
commands:
|
||||
- |
|
||||
bash .buildkite/scripts/hardware_ci/run-cpu-test.sh 20m "
|
||||
pytest -x -v -s tests/kernels/attention/test_cpu_attn.py
|
||||
pytest -x -v -s tests/kernels/moe/test_cpu_fused_moe.py
|
||||
pytest -x -v -s tests/kernels/test_onednn.py
|
||||
pytest -x -v -s tests/kernels/test_awq_int4_to_int8.py"
|
||||
pytest -x -v -s tests/kernels/test_onednn.py"
|
||||
|
||||
- label: CPU-Compatibility Tests
|
||||
depends_on: []
|
||||
|
||||
@@ -36,7 +36,6 @@
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"backend": "vllm",
|
||||
"ignore-eos": "",
|
||||
"temperature": 0,
|
||||
"num_prompts": 200
|
||||
}
|
||||
},
|
||||
@@ -128,4 +127,4 @@
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -22,7 +22,6 @@
|
||||
"hf_split": "test",
|
||||
"no_stream": "",
|
||||
"no_oversample": "",
|
||||
"temperature": 0,
|
||||
"num_prompts": 200
|
||||
}
|
||||
},
|
||||
|
||||
@@ -26,7 +26,6 @@
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"backend": "vllm",
|
||||
"ignore-eos": "",
|
||||
"temperature": 0,
|
||||
"num_prompts": 200
|
||||
}
|
||||
},
|
||||
|
||||
@@ -26,7 +26,6 @@
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"backend": "vllm",
|
||||
"ignore-eos": "",
|
||||
"temperature": 0,
|
||||
"num_prompts": 200
|
||||
}
|
||||
},
|
||||
|
||||
@@ -21,7 +21,6 @@
|
||||
"backend": "vllm",
|
||||
"dataset_name": "sharegpt",
|
||||
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
"temperature": 0,
|
||||
"num_prompts": 200
|
||||
}
|
||||
},
|
||||
@@ -48,7 +47,6 @@
|
||||
"backend": "vllm",
|
||||
"dataset_name": "sharegpt",
|
||||
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
"temperature": 0,
|
||||
"num_prompts": 200
|
||||
}
|
||||
},
|
||||
@@ -75,7 +73,6 @@
|
||||
"backend": "vllm",
|
||||
"dataset_name": "sharegpt",
|
||||
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
"temperature": 0,
|
||||
"num_prompts": 200
|
||||
}
|
||||
},
|
||||
@@ -103,7 +100,6 @@
|
||||
"backend": "vllm",
|
||||
"dataset_name": "sharegpt",
|
||||
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
"temperature": 0,
|
||||
"num_prompts": 200
|
||||
}
|
||||
},
|
||||
@@ -131,7 +127,6 @@
|
||||
"backend": "vllm",
|
||||
"dataset_name": "sharegpt",
|
||||
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
"temperature": 0,
|
||||
"num_prompts": 200
|
||||
}
|
||||
},
|
||||
@@ -156,7 +151,6 @@
|
||||
"backend": "vllm",
|
||||
"dataset_name": "sharegpt",
|
||||
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
"temperature": 0,
|
||||
"num_prompts": 200
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
"backend": "vllm",
|
||||
"dataset_name": "sharegpt",
|
||||
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
"temperature": 0,
|
||||
"num_prompts": 200
|
||||
}
|
||||
},
|
||||
@@ -31,7 +30,6 @@
|
||||
"backend": "vllm",
|
||||
"dataset_name": "sharegpt",
|
||||
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
"temperature": 0,
|
||||
"num_prompts": 200
|
||||
}
|
||||
},
|
||||
@@ -49,7 +47,6 @@
|
||||
"backend": "vllm",
|
||||
"dataset_name": "sharegpt",
|
||||
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
"temperature": 0,
|
||||
"num_prompts": 200
|
||||
}
|
||||
},
|
||||
@@ -70,7 +67,6 @@
|
||||
"backend": "vllm",
|
||||
"dataset_name": "sharegpt",
|
||||
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
"temperature": 0,
|
||||
"num_prompts": 200
|
||||
}
|
||||
}
|
||||
|
||||
@@ -239,29 +239,13 @@ fi
|
||||
# --- Docker housekeeping ---
|
||||
cleanup_docker
|
||||
|
||||
aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin "$REGISTRY"
|
||||
aws ecr get-login-password --region us-east-1 | docker login --username AWS --password-stdin 936637512419.dkr.ecr.us-east-1.amazonaws.com
|
||||
|
||||
# --- Build or pull test image ---
|
||||
IMAGE="${IMAGE_TAG_XPU:-${image_name}}"
|
||||
|
||||
echo "Using image: ${IMAGE}"
|
||||
|
||||
if docker image inspect "${IMAGE}" >/dev/null 2>&1; then
|
||||
echo "Image already exists locally, skipping pull"
|
||||
if [[ -n "${IMAGE_TAG_XPU:-}" ]]; then
|
||||
echo "Using prebuilt XPU image: ${IMAGE_TAG_XPU}"
|
||||
docker pull "${IMAGE_TAG_XPU}"
|
||||
else
|
||||
echo "Image not found locally, waiting for lock..."
|
||||
|
||||
flock /tmp/docker-pull.lock bash -c "
|
||||
if docker image inspect '${IMAGE}' >/dev/null 2>&1; then
|
||||
echo 'Image already pulled by another runner'
|
||||
else
|
||||
echo 'Pulling image...'
|
||||
timeout 900 docker pull '${IMAGE}'
|
||||
fi
|
||||
"
|
||||
|
||||
echo "Pull step completed"
|
||||
echo "Using prebuilt XPU image: ${image_name}"
|
||||
docker pull "${image_name}"
|
||||
fi
|
||||
|
||||
remove_docker_container() {
|
||||
|
||||
@@ -2,6 +2,14 @@ group: Benchmarks
|
||||
depends_on:
|
||||
- image-build
|
||||
steps:
|
||||
- label: Benchmarks
|
||||
timeout_in_minutes: 20
|
||||
working_dir: "/vllm-workspace/.buildkite"
|
||||
source_file_dependencies:
|
||||
- benchmarks/
|
||||
commands:
|
||||
- bash scripts/run-benchmarks.sh
|
||||
|
||||
- label: Benchmarks CLI Test
|
||||
timeout_in_minutes: 20
|
||||
source_file_dependencies:
|
||||
|
||||
@@ -13,8 +13,8 @@ steps:
|
||||
- pytest -v -s distributed/test_eplb_algo.py
|
||||
- pytest -v -s distributed/test_eplb_utils.py
|
||||
|
||||
- label: EPLB Execution # 17min
|
||||
timeout_in_minutes: 27
|
||||
- label: EPLB Execution
|
||||
timeout_in_minutes: 20
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
num_devices: 4
|
||||
source_file_dependencies:
|
||||
|
||||
26
.github/CODEOWNERS
vendored
26
.github/CODEOWNERS
vendored
@@ -2,15 +2,14 @@
|
||||
# for more info about CODEOWNERS file
|
||||
|
||||
# This lists cover the "core" components of vLLM that require careful review
|
||||
/vllm/compilation @zou3519 @youkaichao @ProExpertProg @BoyuanFeng @vadiklyutiy
|
||||
/vllm/compilation @zou3519 @youkaichao @ProExpertProg @BoyuanFeng
|
||||
/vllm/distributed/kv_transfer @NickLucche @ApostaC @orozery
|
||||
/vllm/lora @jeejeelee
|
||||
/vllm/model_executor/layers/attention @LucasWilkinson @MatthewBonanni
|
||||
/vllm/model_executor/layers/fused_moe @mgoin @pavanimajety
|
||||
/vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth @yewentao256 @pavanimajety
|
||||
/vllm/model_executor/layers/mamba @tdoublep @tomeras91
|
||||
/vllm/model_executor/layers/mamba/gdn_linear_attn.py @tdoublep @ZJY0516 @vadiklyutiy
|
||||
/vllm/model_executor/layers/rotary_embedding.py @vadiklyutiy
|
||||
/vllm/model_executor/layers/mamba @tdoublep
|
||||
/vllm/model_executor/layers/mamba/gdn_linear_attn.py @tdoublep @ZJY0516
|
||||
/vllm/model_executor/model_loader @22quinn
|
||||
/vllm/model_executor/layers/batch_invariant.py @yewentao256
|
||||
/vllm/multimodal @DarkLight1337 @ywang96 @NickLucche @tjtanaa
|
||||
@@ -48,9 +47,9 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
|
||||
/vllm/v1/attention @LucasWilkinson @MatthewBonanni
|
||||
/vllm/v1/attention/backend.py @WoosukKwon @zhuohan123 @youkaichao @alexm-redhat @njhill
|
||||
/vllm/v1/attention/backends/mla @pavanimajety
|
||||
/vllm/v1/attention/backends/flashinfer.py @mgoin @pavanimajety @vadiklyutiy
|
||||
/vllm/v1/attention/backends/flashinfer.py @mgoin @pavanimajety
|
||||
/vllm/v1/attention/backends/triton_attn.py @tdoublep
|
||||
/vllm/v1/attention/backends/gdn_attn.py @ZJY0516 @vadiklyutiy
|
||||
/vllm/v1/attention/backends/gdn_attn.py @ZJY0516
|
||||
/vllm/v1/core @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @alexm-redhat @heheda12345 @ApostaC @orozery
|
||||
/vllm/v1/sample @22quinn @houseroad @njhill
|
||||
/vllm/v1/spec_decode @benchislett @luccafong @MatthewBonanni
|
||||
@@ -72,7 +71,7 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
|
||||
/tests/distributed/test_pipeline_parallel.py @youkaichao
|
||||
/tests/distributed/test_same_node.py @youkaichao
|
||||
/tests/entrypoints @DarkLight1337 @robertgshaw2-redhat @aarnphm @NickLucche
|
||||
/tests/evals @mgoin @vadiklyutiy
|
||||
/tests/evals @mgoin
|
||||
/tests/kernels @mgoin @tlrmchlsmth @WoosukKwon @yewentao256
|
||||
/tests/models @DarkLight1337 @ywang96
|
||||
/tests/multimodal @DarkLight1337 @ywang96 @NickLucche
|
||||
@@ -83,7 +82,7 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
|
||||
/tests/v1/core @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @alexm-redhat @heheda12345 @ApostaC @orozery
|
||||
/tests/weight_loading @mgoin @youkaichao @yewentao256
|
||||
/tests/lora @jeejeelee
|
||||
/tests/models/language/generation/test_hybrid.py @tdoublep @tomeras91
|
||||
/tests/models/language/generation/test_hybrid.py @tdoublep
|
||||
/tests/v1/kv_connector/nixl_integration @NickLucche
|
||||
/tests/v1/kv_connector @ApostaC @orozery
|
||||
/tests/v1/kv_offload @ApostaC @orozery
|
||||
@@ -127,14 +126,9 @@ mkdocs.yaml @hmellor
|
||||
/vllm/platforms/xpu.py @jikunshang
|
||||
/docker/Dockerfile.xpu @jikunshang
|
||||
|
||||
# Nemotron-specific files
|
||||
/vllm/model_executor/models/*nemotron* @tomeras91
|
||||
/vllm/transformers_utils/configs/*nemotron* @tomeras91
|
||||
/tests/**/*nemotron* @tomeras91
|
||||
|
||||
# Qwen-specific files
|
||||
/vllm/model_executor/models/qwen* @sighingnow @vadiklyutiy
|
||||
/vllm/transformers_utils/configs/qwen* @sighingnow @vadiklyutiy
|
||||
/vllm/attention/backends/dual_chunk_flash_attn.py @sighingnow
|
||||
/vllm/model_executor/models/qwen* @sighingnow
|
||||
|
||||
# MTP-specific files
|
||||
/vllm/model_executor/models/deepseek_mtp.py @luccafong
|
||||
@@ -150,7 +144,7 @@ mkdocs.yaml @hmellor
|
||||
# Kernels
|
||||
/vllm/v1/attention/ops/chunked_prefill_paged_decode.py @tdoublep
|
||||
/vllm/v1/attention/ops/triton_unified_attention.py @tdoublep
|
||||
/vllm/model_executor/layers/fla @ZJY0516 @vadiklyutiy
|
||||
/vllm/model_executor/layers/fla @ZJY0516
|
||||
|
||||
# ROCm related: specify owner with write access to notify AMD folks for careful code review
|
||||
/vllm/**/*rocm* @tjtanaa
|
||||
|
||||
420
CMakeLists.txt
420
CMakeLists.txt
@@ -309,7 +309,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
|
||||
|
||||
# Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building.
|
||||
set(CUTLASS_REVISION "v4.4.2")
|
||||
set(CUTLASS_REVISION "v4.2.1")
|
||||
|
||||
# Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
|
||||
if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR})
|
||||
@@ -340,6 +340,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
|
||||
list(APPEND VLLM_EXT_SRC
|
||||
"csrc/quantization/awq/gemm_kernels.cu"
|
||||
"csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu"
|
||||
"csrc/quantization/fp4/nvfp4_quant_entry.cu"
|
||||
"csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
|
||||
"csrc/cutlass_extensions/common.cpp")
|
||||
@@ -489,6 +490,132 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
" in CUDA target architectures")
|
||||
endif()
|
||||
|
||||
|
||||
set(SCALED_MM_3X_ARCHS)
|
||||
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
|
||||
# CUDA 12.0 or later
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm90.cu"
|
||||
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8.cu"
|
||||
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_int8.cu"
|
||||
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_azp_sm90_int8.cu"
|
||||
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM90=1")
|
||||
# Let scaled_mm_c2x know it doesn't need to build these arches
|
||||
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
|
||||
message(STATUS "Building scaled_mm_c3x_sm90 for archs: ${SCALED_MM_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS)
|
||||
message(STATUS "Not building scaled_mm_c3x_sm90 as CUDA Compiler version is "
|
||||
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
|
||||
"later if you intend on running FP8 quantized models on "
|
||||
"Hopper.")
|
||||
else()
|
||||
message(STATUS "Not building scaled_mm_c3x_sm90 as no compatible archs found "
|
||||
"in CUDA target architectures")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
||||
# The cutlass_scaled_mm kernels for Blackwell SM12x (c3x, i.e. CUTLASS 3.x) require
|
||||
# CUDA 12.8 or later
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0f" "${CUDA_ARCHS}")
|
||||
else()
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0a;12.1a" "${CUDA_ARCHS}")
|
||||
endif()
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm120.cu"
|
||||
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8.cu"
|
||||
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8.cu"
|
||||
)
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM120=1")
|
||||
# Let scaled_mm_c2x know it doesn't need to build these arches
|
||||
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
|
||||
message(STATUS "Building scaled_mm_c3x_sm120 for archs: ${SCALED_MM_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
message(STATUS "Not building scaled_mm_c3x_sm120 as CUDA Compiler version is "
|
||||
"not >= 12.8, we recommend upgrading to CUDA 12.8 or "
|
||||
"later if you intend on running FP8 quantized models on "
|
||||
"Blackwell.")
|
||||
else()
|
||||
message(STATUS "Not building scaled_mm_c3x_120 as no compatible archs found "
|
||||
"in CUDA target architectures")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
||||
# The cutlass_scaled_mm kernels for Blackwell SM100 (c3x, i.e. CUTLASS 3.x)
|
||||
# require CUDA 12.8 or later
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
|
||||
else()
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
|
||||
endif()
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm100.cu"
|
||||
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu"
|
||||
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8.cu"
|
||||
)
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM100=1")
|
||||
# Let scaled_mm_c2x know it doesn't need to build these arches
|
||||
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
|
||||
message(STATUS "Building scaled_mm_c3x_sm100 for archs: ${SCALED_MM_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
message(STATUS "Not building scaled_mm_c3x_sm100 as CUDA Compiler version is "
|
||||
"not >= 12.8, we recommend upgrading to CUDA 12.8 or "
|
||||
"later if you intend on running FP8 quantized models on "
|
||||
"Blackwell.")
|
||||
else()
|
||||
message(STATUS "Not building scaled_mm_c3x_100 as no compatible archs found "
|
||||
"in CUDA target architectures")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
#
|
||||
# For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x)
|
||||
# kernels for the remaining archs that are not already built for 3x.
|
||||
# (Build 8.9 for FP8)
|
||||
cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS
|
||||
"7.5;8.0;8.7;8.9+PTX" "${CUDA_ARCHS}")
|
||||
# subtract out the archs that are already built for 3x
|
||||
list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS})
|
||||
if (SCALED_MM_2X_ARCHS)
|
||||
set(SRCS "csrc/quantization/w8a8/cutlass/scaled_mm_c2x.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_2X_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_C2X=1")
|
||||
message(STATUS "Building scaled_mm_c2x for archs: ${SCALED_MM_2X_ARCHS}")
|
||||
else()
|
||||
if (SCALED_MM_3X_ARCHS)
|
||||
message(STATUS "Not building scaled_mm_c2x as all archs are already built"
|
||||
" for and covered by scaled_mm_c3x")
|
||||
else()
|
||||
message(STATUS "Not building scaled_mm_c2x as no compatible archs found "
|
||||
"in CUDA target architectures")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# The nvfp4_scaled_mm_sm120 kernels for Blackwell SM12x require
|
||||
# CUDA 12.8 or later
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||
@@ -566,6 +693,55 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
set(MLA_ARCHS)
|
||||
endif()
|
||||
|
||||
# CUTLASS MoE kernels
|
||||
|
||||
# The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and ONLY works
|
||||
# on Hopper). get_cutlass_(batched_)moe_mm_data should only be compiled
|
||||
# if it's possible to compile MoE kernels that use its output.
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
|
||||
set(SRCS "csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm90.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM90=1")
|
||||
message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
|
||||
message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is "
|
||||
"not >= 12.3, we recommend upgrading to CUDA 12.3 or later "
|
||||
"if you intend on running FP8 quantized MoE models on Hopper.")
|
||||
else()
|
||||
message(STATUS "Not building grouped_mm_c3x as no compatible archs found "
|
||||
"in CUDA target architectures.")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
|
||||
else()
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
|
||||
endif()
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
set(SRCS "csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm100.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
|
||||
message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is "
|
||||
"not >= 12.8, we recommend upgrading to CUDA 12.8 or later "
|
||||
"if you intend on running FP8 quantized MoE models on Blackwell.")
|
||||
else()
|
||||
message(STATUS "Not building grouped_mm_c3x as no compatible archs found "
|
||||
"in CUDA target architectures.")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# Expert-specialization MXFP8 blockscaled grouped kernels (SM100+).
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||
cuda_archs_loose_intersection(ES_MXFP8_GROUPED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
|
||||
@@ -611,6 +787,36 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
"in CUDA target architectures.")
|
||||
endif()
|
||||
|
||||
# moe_data.cu is used by all CUTLASS MoE kernels.
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||
cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
|
||||
else()
|
||||
cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}")
|
||||
endif()
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)
|
||||
set(SRCS "csrc/quantization/w8a8/cutlass/moe/moe_data.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${CUTLASS_MOE_DATA_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
message(STATUS "Building moe_data for archs: ${CUTLASS_MOE_DATA_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)
|
||||
message(STATUS "Not building moe_data as CUDA Compiler version is "
|
||||
"not >= 12.3, we recommend upgrading to CUDA 12.3 or later "
|
||||
"if you intend on running FP8 quantized MoE models on Hopper or Blackwell.")
|
||||
else()
|
||||
message(STATUS "Not building moe_data as no compatible archs found "
|
||||
"in CUDA target architectures.")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
|
||||
else()
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
|
||||
endif()
|
||||
|
||||
#
|
||||
# Machete kernels
|
||||
|
||||
@@ -758,9 +964,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
# _C_stable_libtorch extension (ops registered via STABLE_TORCH_LIBRARY)
|
||||
#
|
||||
set(VLLM_STABLE_EXT_SRC
|
||||
"csrc/libtorch_stable/torch_bindings.cpp"
|
||||
"csrc/cutlass_extensions/common.cpp"
|
||||
"csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu")
|
||||
"csrc/libtorch_stable/torch_bindings.cpp")
|
||||
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
list(APPEND VLLM_STABLE_EXT_SRC
|
||||
@@ -775,209 +979,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
CUDA_ARCHS "${CUDA_ARCHS}")
|
||||
endif()
|
||||
|
||||
#
|
||||
# CUTLASS scaled_mm kernels (moved from _C to _C_stable_libtorch)
|
||||
#
|
||||
set(SCALED_MM_3X_ARCHS)
|
||||
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
|
||||
# CUDA 12.0 or later
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c3x_sm90.cu"
|
||||
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8.cu"
|
||||
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_int8.cu"
|
||||
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_azp_sm90_int8.cu"
|
||||
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
||||
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM90=1")
|
||||
# Let scaled_mm_c2x know it doesn't need to build these arches
|
||||
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
|
||||
message(STATUS "Building scaled_mm_c3x_sm90 for archs: ${SCALED_MM_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS)
|
||||
message(STATUS "Not building scaled_mm_c3x_sm90 as CUDA Compiler version is "
|
||||
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
|
||||
"later if you intend on running FP8 quantized models on "
|
||||
"Hopper.")
|
||||
else()
|
||||
message(STATUS "Not building scaled_mm_c3x_sm90 as no compatible archs found "
|
||||
"in CUDA target architectures")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
||||
# The cutlass_scaled_mm kernels for Blackwell SM12x (c3x, i.e. CUTLASS 3.x) require
|
||||
# CUDA 12.8 or later
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0f" "${CUDA_ARCHS}")
|
||||
else()
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0a;12.1a" "${CUDA_ARCHS}")
|
||||
endif()
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c3x_sm120.cu"
|
||||
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8.cu"
|
||||
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8.cu"
|
||||
)
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
||||
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM120=1")
|
||||
# Let scaled_mm_c2x know it doesn't need to build these arches
|
||||
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
|
||||
message(STATUS "Building scaled_mm_c3x_sm120 for archs: ${SCALED_MM_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
message(STATUS "Not building scaled_mm_c3x_sm120 as CUDA Compiler version is "
|
||||
"not >= 12.8, we recommend upgrading to CUDA 12.8 or "
|
||||
"later if you intend on running FP8 quantized models on "
|
||||
"Blackwell.")
|
||||
else()
|
||||
message(STATUS "Not building scaled_mm_c3x_120 as no compatible archs found "
|
||||
"in CUDA target architectures")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
||||
# The cutlass_scaled_mm kernels for Blackwell SM100 (c3x, i.e. CUTLASS 3.x)
|
||||
# require CUDA 12.8 or later
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
|
||||
else()
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
|
||||
endif()
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c3x_sm100.cu"
|
||||
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu"
|
||||
"csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8.cu"
|
||||
)
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
||||
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM100=1")
|
||||
# Let scaled_mm_c2x know it doesn't need to build these arches
|
||||
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
|
||||
message(STATUS "Building scaled_mm_c3x_sm100 for archs: ${SCALED_MM_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
message(STATUS "Not building scaled_mm_c3x_sm100 as CUDA Compiler version is "
|
||||
"not >= 12.8, we recommend upgrading to CUDA 12.8 or "
|
||||
"later if you intend on running FP8 quantized models on "
|
||||
"Blackwell.")
|
||||
else()
|
||||
message(STATUS "Not building scaled_mm_c3x_100 as no compatible archs found "
|
||||
"in CUDA target architectures")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
#
|
||||
# For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x)
|
||||
# kernels for the remaining archs that are not already built for 3x.
|
||||
# (Build 8.9 for FP8)
|
||||
cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS
|
||||
"7.5;8.0;8.7;8.9+PTX" "${CUDA_ARCHS}")
|
||||
# subtract out the archs that are already built for 3x
|
||||
list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS})
|
||||
if (SCALED_MM_2X_ARCHS)
|
||||
set(SRCS "csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c2x.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_2X_ARCHS}")
|
||||
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_C2X=1")
|
||||
message(STATUS "Building scaled_mm_c2x for archs: ${SCALED_MM_2X_ARCHS}")
|
||||
else()
|
||||
if (SCALED_MM_3X_ARCHS)
|
||||
message(STATUS "Not building scaled_mm_c2x as all archs are already built"
|
||||
" for and covered by scaled_mm_c3x")
|
||||
else()
|
||||
message(STATUS "Not building scaled_mm_c2x as no compatible archs found "
|
||||
"in CUDA target architectures")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
#
|
||||
# CUTLASS MoE kernels (moved from _C to _C_stable_libtorch)
|
||||
#
|
||||
|
||||
# The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and ONLY works
|
||||
# on Hopper). get_cutlass_(batched_)moe_mm_data should only be compiled
|
||||
# if it's possible to compile MoE kernels that use its output.
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
|
||||
set(SRCS "csrc/libtorch_stable/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm90.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
||||
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM90=1")
|
||||
message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
|
||||
message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is "
|
||||
"not >= 12.3, we recommend upgrading to CUDA 12.3 or later "
|
||||
"if you intend on running FP8 quantized MoE models on Hopper.")
|
||||
else()
|
||||
message(STATUS "Not building grouped_mm_c3x as no compatible archs found "
|
||||
"in CUDA target architectures.")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
|
||||
else()
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
|
||||
endif()
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
set(SRCS "csrc/libtorch_stable/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm100.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
||||
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
|
||||
message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is "
|
||||
"not >= 12.8, we recommend upgrading to CUDA 12.8 or later "
|
||||
"if you intend on running FP8 quantized MoE models on Blackwell.")
|
||||
else()
|
||||
message(STATUS "Not building grouped_mm_c3x as no compatible archs found "
|
||||
"in CUDA target architectures.")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# moe_data.cu is used by all CUTLASS MoE kernels.
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||
cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
|
||||
else()
|
||||
cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}")
|
||||
endif()
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)
|
||||
set(SRCS "csrc/libtorch_stable/quantization/w8a8/cutlass/moe/moe_data.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${CUTLASS_MOE_DATA_ARCHS}")
|
||||
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
|
||||
message(STATUS "Building moe_data for archs: ${CUTLASS_MOE_DATA_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)
|
||||
message(STATUS "Not building moe_data as CUDA Compiler version is "
|
||||
"not >= 12.3, we recommend upgrading to CUDA 12.3 or later "
|
||||
"if you intend on running FP8 quantized MoE models on Hopper or Blackwell.")
|
||||
else()
|
||||
message(STATUS "Not building moe_data as no compatible archs found "
|
||||
"in CUDA target architectures.")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
message(STATUS "Enabling C_stable extension.")
|
||||
define_extension_target(
|
||||
_C_stable_libtorch
|
||||
@@ -986,7 +987,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
SOURCES ${VLLM_STABLE_EXT_SRC}
|
||||
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
|
||||
ARCHITECTURES ${VLLM_GPU_ARCHES}
|
||||
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
|
||||
USE_SABI 3
|
||||
WITH_SOABI)
|
||||
|
||||
@@ -1000,10 +1000,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
# Needed to use cuda APIs from C-shim
|
||||
target_compile_definitions(_C_stable_libtorch PRIVATE
|
||||
USE_CUDA)
|
||||
|
||||
# Needed by CUTLASS kernels
|
||||
target_compile_definitions(_C_stable_libtorch PRIVATE
|
||||
CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
|
||||
endif()
|
||||
|
||||
#
|
||||
|
||||
@@ -373,7 +373,6 @@ if (ENABLE_X86_ISA)
|
||||
"csrc/cpu/sgl-kernels/gemm.cpp"
|
||||
"csrc/cpu/sgl-kernels/gemm_int8.cpp"
|
||||
"csrc/cpu/sgl-kernels/gemm_fp8.cpp"
|
||||
"csrc/cpu/sgl-kernels/gemm_int4.cpp"
|
||||
"csrc/cpu/sgl-kernels/moe.cpp"
|
||||
"csrc/cpu/sgl-kernels/moe_int8.cpp"
|
||||
"csrc/cpu/sgl-kernels/moe_fp8.cpp")
|
||||
|
||||
@@ -117,14 +117,6 @@ inline void parallel_for(int n, const func_t& f) {
|
||||
#endif
|
||||
}
|
||||
|
||||
inline int get_thread_num() {
|
||||
#if defined(_OPENMP)
|
||||
return omp_get_thread_num();
|
||||
#else
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
|
||||
// for 1d parallel, use `actual_nth`
|
||||
// for 2d parallel, use even nths, e.g. 43->42
|
||||
int inline adjust_num_threads(int m) {
|
||||
|
||||
@@ -17,8 +17,8 @@ constexpr int block_size_n() { return 2 * TILE_N; }
|
||||
template <typename T> inline bool can_use_brgemm(int M);
|
||||
template <> inline bool can_use_brgemm<at::BFloat16>(int M) { return M > 4; }
|
||||
template <> inline bool can_use_brgemm<at::Half>(int M) { return true; }
|
||||
template <> inline bool can_use_brgemm<int8_t>(int M) { return M > 4; }
|
||||
template <> inline bool can_use_brgemm<uint8_t>(int M) { return M > 4; }
|
||||
// TODO: add u8s8 brgemm, this requires PyTorch 2.7
|
||||
template <> inline bool can_use_brgemm<int8_t>(int M) { return false; }
|
||||
template <> inline bool can_use_brgemm<at::Float8_e4m3fn>(int M) { return M > 4; }
|
||||
template <> inline bool can_use_brgemm<at::quint4x2>(int M) { return M > 4; }
|
||||
|
||||
@@ -40,17 +40,9 @@ inline int64_t get_row_size(int64_t K, bool use_int8_w8a8) {
|
||||
return use_int8_w8a8 ? K + sizeof(int32_t) : K;
|
||||
}
|
||||
|
||||
inline int64_t get_4bit_block_k_size(int64_t group_size) {
|
||||
return group_size > 128 ? 128 : group_size;
|
||||
}
|
||||
|
||||
// pack weight into vnni format
|
||||
// pack weight to vnni format
|
||||
at::Tensor convert_weight_packed(at::Tensor& weight);
|
||||
|
||||
// pack weight to vnni format for int4 (adapted from sglang)
|
||||
std::tuple<at::Tensor, at::Tensor, at::Tensor>
|
||||
convert_weight_packed_scale_zp(at::Tensor qweight, at::Tensor qzeros, at::Tensor scales);
|
||||
|
||||
// moe implementations for int8 w8a8
|
||||
template <typename scalar_t>
|
||||
void fused_experts_int8_kernel_impl(
|
||||
@@ -241,31 +233,6 @@ void tinygemm_kernel(
|
||||
int64_t strideBs,
|
||||
bool brg);
|
||||
|
||||
// int4 scaled GEMM (adapted from sglang)
|
||||
at::Tensor int4_scaled_mm_cpu(
|
||||
at::Tensor& x, at::Tensor& w, at::Tensor& w_zeros, at::Tensor& w_scales, std::optional<at::Tensor> bias);
|
||||
|
||||
// int4 tinygemm kernel interface(adapted from sglang)
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(
|
||||
scalar_t* C,
|
||||
float* C_temp,
|
||||
const uint8_t* A,
|
||||
const float* scales_a,
|
||||
const int32_t* qzeros_a,
|
||||
const uint8_t* B,
|
||||
const float* scales_b,
|
||||
const int8_t* qzeros_b,
|
||||
const int32_t* compensation,
|
||||
int8_t* dqB_tmp,
|
||||
int64_t M,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldc_f,
|
||||
int64_t ldc_s,
|
||||
bool store_out,
|
||||
bool use_brgemm);
|
||||
|
||||
// TODO: debug print, remove me later
|
||||
inline void print_16x32i(const __m512i x) {
|
||||
int32_t a[16];
|
||||
|
||||
@@ -1,755 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// Adapted from sgl-project/sglang
|
||||
// https://github.com/sgl-project/sglang/pull/8226
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
#include "common.h"
|
||||
#include "gemm.h"
|
||||
#include "vec.h"
|
||||
|
||||
namespace {
|
||||
|
||||
#define BLOCK_N block_size_n()
|
||||
#define BLOCK_M 128
|
||||
|
||||
template <bool sym_quant_act>
|
||||
struct ActDtype;
|
||||
template <>
|
||||
struct ActDtype<true> {
|
||||
using type = int8_t;
|
||||
};
|
||||
template <>
|
||||
struct ActDtype<false> {
|
||||
using type = uint8_t;
|
||||
};
|
||||
|
||||
struct alignas(32) m256i_wrapper {
|
||||
__m256i data;
|
||||
};
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
inline std::array<m256i_wrapper, 2> load_zps_4vnni(
|
||||
const int8_t* __restrict__ zps) {
|
||||
__m256i vzps_low = _mm256_set1_epi64x(*reinterpret_cast<const int64_t*>(zps));
|
||||
__m256i vzps_high =
|
||||
_mm256_set1_epi64x(*reinterpret_cast<const int64_t*>(zps + 8));
|
||||
__m256i shuffle_mask =
|
||||
_mm256_set_epi8(7, 7, 7, 7, 6, 6, 6, 6, 5, 5, 5, 5, 4, 4, 4, 4, 3, 3, 3,
|
||||
3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0);
|
||||
vzps_low = _mm256_shuffle_epi8(vzps_low, shuffle_mask);
|
||||
vzps_high = _mm256_shuffle_epi8(vzps_high, shuffle_mask);
|
||||
m256i_wrapper vzps_low_wp, vzps_high_wp;
|
||||
vzps_low_wp.data = vzps_low;
|
||||
vzps_high_wp.data = vzps_high;
|
||||
return {vzps_low_wp, vzps_high_wp};
|
||||
}
|
||||
|
||||
inline std::array<m256i_wrapper, 2> load_uint4_as_int8(
|
||||
const uint8_t* __restrict__ qB) {
|
||||
__m256i packed = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(qB));
|
||||
const __m256i low_mask = _mm256_set1_epi8(0x0f);
|
||||
__m256i high = _mm256_srli_epi16(packed, 4);
|
||||
high = _mm256_and_si256(high, low_mask);
|
||||
__m256i low = _mm256_and_si256(packed, low_mask);
|
||||
m256i_wrapper low_wp, high_wp;
|
||||
low_wp.data = low;
|
||||
high_wp.data = high;
|
||||
return {low_wp, high_wp};
|
||||
}
|
||||
|
||||
template <int N, int ldb>
|
||||
void _dequant_weight_zp_only(const uint8_t* __restrict__ B, int8_t* dqB,
|
||||
const int8_t* __restrict__ qzeros, int64_t K) {
|
||||
#pragma GCC unroll 2
|
||||
for (int n = 0; n < N; n += 16) {
|
||||
auto [zps_low_wp, zps_high_wp] = load_zps_4vnni(&qzeros[n]);
|
||||
auto zps_low = zps_low_wp.data;
|
||||
auto zps_high = zps_high_wp.data;
|
||||
for (int k = 0; k < K; k += 4) {
|
||||
auto [vb_low_wp, vb_high_wp] =
|
||||
load_uint4_as_int8(B + ldb * k + n / 2 * 4);
|
||||
auto vb_low = vb_low_wp.data;
|
||||
auto vb_high = vb_high_wp.data;
|
||||
vb_high = _mm256_sub_epi8(vb_high, zps_high);
|
||||
vb_low = _mm256_sub_epi8(vb_low, zps_low);
|
||||
_mm256_storeu_si256(reinterpret_cast<__m256i_u*>(dqB + N * k + n * 4),
|
||||
vb_low);
|
||||
_mm256_storeu_si256(
|
||||
reinterpret_cast<__m256i_u*>(dqB + N * k + (n + 8) * 4), vb_high);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <bool sym_quant_act, int N, bool accum>
|
||||
void _dequant_and_store(float* __restrict__ output,
|
||||
const int32_t* __restrict__ input,
|
||||
const float* __restrict__ scale_a,
|
||||
const int32_t* __restrict__ zp_a,
|
||||
const float* __restrict__ scale_b,
|
||||
const int32_t* __restrict__ comp_b, int M, int ldi,
|
||||
int ldo, int ldsa = 1) {
|
||||
for (int m = 0; m < M; ++m) {
|
||||
float a_scale = *(scale_a + m * ldsa);
|
||||
__m512 va_scale = _mm512_set1_ps(a_scale);
|
||||
int32_t a_zp;
|
||||
__m512i va_zp;
|
||||
if constexpr (!sym_quant_act) {
|
||||
a_zp = *(zp_a + m * ldsa);
|
||||
va_zp = _mm512_set1_epi32(a_zp);
|
||||
}
|
||||
int n = 0;
|
||||
#pragma GCC unroll 2
|
||||
for (; n < N; n += 16) {
|
||||
__m512i vc = _mm512_loadu_si512(input + m * ldi + n);
|
||||
if constexpr (!sym_quant_act) {
|
||||
__m512i vb_comp = _mm512_loadu_si512(comp_b + n);
|
||||
vc = _mm512_sub_epi32(vc, _mm512_mullo_epi32(vb_comp, va_zp));
|
||||
}
|
||||
__m512 vc_f = _mm512_cvtepi32_ps(vc);
|
||||
__m512 vc_f_mul = _mm512_mul_ps(vc_f, va_scale);
|
||||
__m512 vb_s = _mm512_loadu_ps(scale_b + n);
|
||||
vc_f_mul = _mm512_mul_ps(vc_f_mul, vb_s);
|
||||
if constexpr (accum) {
|
||||
__m512 vo = _mm512_loadu_ps(output + m * ldo + n);
|
||||
_mm512_storeu_ps(output + m * ldo + n, _mm512_add_ps(vo, vc_f_mul));
|
||||
} else {
|
||||
_mm512_storeu_ps(output + m * ldo + n, vc_f_mul);
|
||||
}
|
||||
}
|
||||
for (; n < N; ++n) {
|
||||
float dq_val;
|
||||
if constexpr (sym_quant_act) {
|
||||
dq_val = (float)input[m * ldi + n] * a_scale * scale_b[n];
|
||||
} else {
|
||||
dq_val = (float)(input[m * ldi + n] - a_zp * comp_b[n]) * a_scale *
|
||||
scale_b[n];
|
||||
}
|
||||
if constexpr (accum) {
|
||||
output[m * ldo + n] += dq_val;
|
||||
} else {
|
||||
output[m * ldo + n] = dq_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#else
|
||||
template <int N, int ldb>
|
||||
void _dequant_weight_zp_only(const uint8_t* B, int8_t* dqB,
|
||||
const int8_t* qzeros, int64_t K) {
|
||||
for (int k = 0; k < K; ++k) {
|
||||
for (int n = 0; n < N / 2; ++n) {
|
||||
int32_t b = (int32_t)B[k * ldb + n];
|
||||
dqB[k * N + n * 2] = (b & 0xf) - qzeros[n];
|
||||
dqB[k * N + n * 2 + 1] = (b >> 4) - qzeros[n];
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
inline __m512i combine_m256i(__m256i a, __m256i b) {
|
||||
__m512i c = _mm512_castsi256_si512(a);
|
||||
return _mm512_inserti64x4(c, b, 1);
|
||||
}
|
||||
|
||||
inline __m512i combine_m256i(std::array<m256i_wrapper, 2> two_256) {
|
||||
return combine_m256i(two_256[0].data, two_256[1].data);
|
||||
}
|
||||
|
||||
static inline __m512i _mm512_sign_epi8(__m512i a, __m512i b) {
|
||||
__m512i zero = _mm512_setzero_si512();
|
||||
__mmask64 blt0 = _mm512_movepi8_mask(b);
|
||||
return _mm512_mask_sub_epi8(a, blt0, zero, a);
|
||||
}
|
||||
|
||||
template <bool sym_quant_act, int M, int N, int ldb>
|
||||
void _dequant_gemm_accum_small_M(float* __restrict__ C, const uint8_t* A,
|
||||
const float* scales_a, const int32_t* qzeros_a,
|
||||
const uint8_t* B, const float* scales_b,
|
||||
const int8_t* qzeros_b, int64_t K, int64_t lda,
|
||||
int64_t ldc) {
|
||||
constexpr int COLS = N / 16;
|
||||
__m512i ones = _mm512_set1_epi8(1);
|
||||
__m512i va;
|
||||
__m512i vb[COLS];
|
||||
__m512i vc[M * COLS];
|
||||
__m512 vscales[COLS];
|
||||
__m512i vzps[COLS];
|
||||
__m512i vcompensate[COLS];
|
||||
|
||||
Unroll<COLS>{}([&](auto i) {
|
||||
vscales[i] = _mm512_loadu_ps(scales_b + i * 16);
|
||||
vzps[i] = combine_m256i(load_zps_4vnni(qzeros_b + i * 16));
|
||||
if constexpr (!sym_quant_act) {
|
||||
vcompensate[i] = _mm512_setzero_epi32();
|
||||
}
|
||||
});
|
||||
Unroll<M * COLS>{}([&](auto i) { vc[i] = _mm512_setzero_epi32(); });
|
||||
|
||||
auto compute = [&](auto i, int k) {
|
||||
constexpr const int row = i / COLS;
|
||||
constexpr const int col = i % COLS;
|
||||
|
||||
if constexpr (col == 0) {
|
||||
va = _mm512_set1_epi32(*(int32_t*)(A + row * lda + k));
|
||||
}
|
||||
|
||||
if constexpr (row == 0) {
|
||||
int B_offset = k * ldb + col * 16 * 2;
|
||||
vb[col] = combine_m256i(load_uint4_as_int8(B + B_offset));
|
||||
vb[col] = _mm512_sub_epi8(vb[col], vzps[col]);
|
||||
if constexpr (!sym_quant_act) {
|
||||
vcompensate[col] = _mm512_dpbusd_epi32(vcompensate[col], ones, vb[col]);
|
||||
}
|
||||
_mm_prefetch(B + B_offset + 128 * ldb, _MM_HINT_T0);
|
||||
}
|
||||
if constexpr (sym_quant_act) {
|
||||
auto vsb = _mm512_sign_epi8(vb[col], va);
|
||||
auto vabsa = _mm512_sign_epi8(va, va);
|
||||
vc[i] = _mm512_dpbusds_epi32(vc[i], vabsa, vsb);
|
||||
} else {
|
||||
vc[i] = _mm512_dpbusd_epi32(vc[i], va, vb[col]);
|
||||
}
|
||||
};
|
||||
|
||||
constexpr const int unroll = 4;
|
||||
int k = 0;
|
||||
for (; k < K / 4 / unroll; k++) {
|
||||
Unroll<unroll>{}(
|
||||
[&](auto i) { Unroll<M * COLS>{}(compute, 4 * (k * unroll + i)); });
|
||||
}
|
||||
k *= 4 * unroll;
|
||||
for (; k < K; k += 4) {
|
||||
Unroll<M * COLS>{}(compute, k);
|
||||
}
|
||||
|
||||
auto store = [&](auto i) {
|
||||
constexpr const int row = i / COLS;
|
||||
constexpr const int col = i % COLS;
|
||||
__m512 vc_float;
|
||||
if constexpr (!sym_quant_act) {
|
||||
vc[i] = _mm512_sub_epi32(
|
||||
vc[i], _mm512_mullo_epi32(vcompensate[col],
|
||||
_mm512_set1_epi32(*(qzeros_a + row))));
|
||||
}
|
||||
vc_float = _mm512_cvtepi32_ps(vc[i]);
|
||||
vc_float = _mm512_mul_ps(vc_float, _mm512_set1_ps(*(scales_a + row)));
|
||||
|
||||
vc_float = _mm512_mul_ps(vc_float, vscales[col]);
|
||||
auto vc_old = _mm512_loadu_ps(C + row * ldc + col * 16);
|
||||
vc_float = _mm512_add_ps(vc_float, vc_old);
|
||||
_mm512_storeu_ps(C + row * ldc + col * 16, vc_float);
|
||||
};
|
||||
Unroll<M * COLS>{}(store);
|
||||
}
|
||||
|
||||
#define CALL_DEQUANT_GEMM_ACCUM_SMALL_M(M) \
|
||||
_dequant_gemm_accum_small_M<sym_quant_act, M, N, ldb>( \
|
||||
C, A, scales_a, qzeros_a, B, scales_b, qzeros_b, K, lda, ldc);
|
||||
#endif
|
||||
|
||||
template <bool sym_quant_act, int N, int ldb>
|
||||
void _dequant_gemm_accum(float* C, const uint8_t* A, const float* scales_a,
|
||||
const int32_t* qzeros_a, const uint8_t* B,
|
||||
const float* scales_b, const int8_t* qzeros_b,
|
||||
const int32_t* compensation, int8_t* dqB, int64_t M,
|
||||
int64_t K, int64_t lda, int64_t ldc, bool use_brgemm) {
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
if (!use_brgemm) {
|
||||
switch (M) {
|
||||
case 1:
|
||||
CALL_DEQUANT_GEMM_ACCUM_SMALL_M(1);
|
||||
break;
|
||||
case 2:
|
||||
CALL_DEQUANT_GEMM_ACCUM_SMALL_M(2);
|
||||
break;
|
||||
case 3:
|
||||
CALL_DEQUANT_GEMM_ACCUM_SMALL_M(3);
|
||||
break;
|
||||
case 4:
|
||||
CALL_DEQUANT_GEMM_ACCUM_SMALL_M(4);
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(false, "tinygemm_kernel: unexpected M for AVX path!");
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
_dequant_weight_zp_only<N, ldb>(B, dqB, qzeros_b, K);
|
||||
using Tin = typename ActDtype<sym_quant_act>::type;
|
||||
Tin* A_ptr = (Tin*)A;
|
||||
if (use_brgemm) {
|
||||
int32_t C_i32[M * N];
|
||||
at::native::cpublas::brgemm(M, N, K, lda, N /*ldb*/, N /*ldc*/,
|
||||
false /* add_C */, A_ptr, dqB, C_i32,
|
||||
true /* is_vnni */);
|
||||
_mm_prefetch(B + N * K / 2, _MM_HINT_T0);
|
||||
_mm_prefetch(A + K, _MM_HINT_T0);
|
||||
_dequant_and_store<sym_quant_act, N, true>(C, C_i32, scales_a, qzeros_a,
|
||||
scales_b, compensation, M,
|
||||
N /*ldi*/, ldc, 1 /*ldsa*/);
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
TORCH_CHECK(false, "tinygemm_kernel: scalar path not implemented!");
|
||||
}
|
||||
}
|
||||
|
||||
template <int N>
|
||||
inline void copy_bias(const float* bias_ptr, float* y_buf, int64_t m) {
|
||||
if (bias_ptr) {
|
||||
for (int i = 0; i < m; ++i) {
|
||||
int j = 0;
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
#pragma GCC unroll 2
|
||||
for (; j < N; j += 16) {
|
||||
__m512 bias_vec = _mm512_loadu_ps(bias_ptr + j);
|
||||
_mm512_storeu_ps(y_buf + i * N + j, bias_vec);
|
||||
}
|
||||
#endif
|
||||
for (; j < N; ++j) {
|
||||
y_buf[i * N + j] = bias_ptr[j];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < m; ++i) {
|
||||
int j = 0;
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
#pragma GCC unroll 2
|
||||
for (; j < N; j += 16) {
|
||||
__m512 zero_vec = _mm512_setzero_ps();
|
||||
_mm512_storeu_ps(y_buf + i * N + j, zero_vec);
|
||||
}
|
||||
#endif
|
||||
for (; j < N; ++j) {
|
||||
y_buf[i * N + j] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int N, typename out_dtype>
|
||||
inline void store_out(const float* y_buf, out_dtype* c_ptr, int64_t m,
|
||||
int64_t lda) {
|
||||
for (int i = 0; i < m; ++i) {
|
||||
int j = 0;
|
||||
if constexpr (std::is_same<out_dtype, float>::value) {
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
#pragma GCC unroll 2
|
||||
for (; j < N; j += 16) {
|
||||
__m512 y_vec = _mm512_loadu_ps(y_buf + i * N + j);
|
||||
_mm512_storeu_ps(c_ptr + i * lda + j, y_vec);
|
||||
}
|
||||
#endif
|
||||
for (; j < N; ++j) {
|
||||
c_ptr[i * lda + j] = y_buf[i * N + j];
|
||||
}
|
||||
} else if constexpr (std::is_same<out_dtype, at::BFloat16>::value) {
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
#pragma GCC unroll 2
|
||||
for (; j < N; j += 16) {
|
||||
__m512 y_vec = _mm512_loadu_ps(y_buf + i * N + j);
|
||||
__m256i y_bf16_vec = at::vec::cvtfp32_bf16(y_vec);
|
||||
_mm256_storeu_si256(reinterpret_cast<__m256i*>(c_ptr + i * lda + j),
|
||||
y_bf16_vec);
|
||||
}
|
||||
#endif
|
||||
for (; j < N; ++j) {
|
||||
c_ptr[i * lda + j] = at::BFloat16(y_buf[i * N + j]);
|
||||
}
|
||||
} else if constexpr (std::is_same<out_dtype, at::Half>::value) {
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
#pragma GCC unroll 2
|
||||
for (; j < N; j += 16) {
|
||||
__m512 y_vec = _mm512_loadu_ps(y_buf + i * N + j);
|
||||
__m256i y_fp16_vec = at::vec::cvtfp32_fp16(y_vec);
|
||||
_mm256_storeu_si256(reinterpret_cast<__m256i*>(c_ptr + i * lda + j),
|
||||
y_fp16_vec);
|
||||
}
|
||||
#endif
|
||||
for (; j < N; ++j) {
|
||||
c_ptr[i * lda + j] = at::Half(y_buf[i * N + j]);
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported output dtype");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void fill_val_stub(int32_t* __restrict__ output, int32_t value, int64_t size) {
|
||||
using iVec = at::vec::Vectorized<int32_t>;
|
||||
constexpr int VecSize = iVec::size();
|
||||
const iVec fill_val_vec = iVec(value);
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= size - VecSize; d += VecSize) {
|
||||
fill_val_vec.store(output + d);
|
||||
}
|
||||
for (; d < size; ++d) {
|
||||
output[d] = value;
|
||||
}
|
||||
}
|
||||
|
||||
template <bool sym_quant_act, typename act_dtype, typename out_dtype>
|
||||
void _da8w4_linear_impl(
|
||||
act_dtype* __restrict__ input, const float* __restrict__ input_scales,
|
||||
const int32_t* __restrict__ input_qzeros,
|
||||
const uint8_t* __restrict__ weight, const float* __restrict__ weight_scales,
|
||||
const int8_t* __restrict__ weight_qzeros, const float* __restrict__ bias,
|
||||
out_dtype* __restrict__ output, float* __restrict__ output_temp,
|
||||
int8_t* __restrict__ dequant_weight_temp, int64_t M, int64_t N, int64_t K,
|
||||
int64_t num_groups) {
|
||||
const bool use_brgemm = can_use_brgemm<act_dtype>(M);
|
||||
int64_t block_m = [&]() -> long {
|
||||
if (M <= 48) {
|
||||
return M;
|
||||
} else if (M < 64) {
|
||||
return 32;
|
||||
} else if (M < 96) {
|
||||
return 64;
|
||||
} else {
|
||||
return 128;
|
||||
}
|
||||
}();
|
||||
int64_t Mc = div_up(M, block_m);
|
||||
bool parallel_on_M = M > 128;
|
||||
int64_t Nc = N / BLOCK_N;
|
||||
int64_t num_blocks = parallel_on_M ? Mc * Nc : Nc;
|
||||
int64_t group_size = div_up(K, num_groups);
|
||||
int64_t _block_k = get_4bit_block_k_size(group_size);
|
||||
int64_t Kc = K / _block_k;
|
||||
int64_t block_per_group = group_size / _block_k;
|
||||
|
||||
at::parallel_for(0, num_blocks, 1, [&](int64_t begin, int64_t end) {
|
||||
int tid = get_thread_num();
|
||||
float* C_tmp = output_temp + tid * block_m * BLOCK_N;
|
||||
int8_t* dqB_tmp = dequant_weight_temp + tid * _block_k * BLOCK_N;
|
||||
for (const auto i : c10::irange(begin, end)) {
|
||||
int64_t mc = parallel_on_M ? i / Nc : 0;
|
||||
int64_t nc = parallel_on_M ? i % Nc : i;
|
||||
int64_t mc_end = parallel_on_M ? mc + 1 : Mc;
|
||||
|
||||
for (int mci = mc; mci < mc_end; ++mci) {
|
||||
int64_t m_size =
|
||||
mci * block_m + block_m > M ? M - mci * block_m : block_m;
|
||||
auto bias_data = bias ? bias + nc * BLOCK_N : nullptr;
|
||||
copy_bias<BLOCK_N>(bias_data, C_tmp, m_size);
|
||||
for (int kci = 0; kci < Kc; ++kci) {
|
||||
int32_t* compensation_ptr =
|
||||
sym_quant_act
|
||||
? nullptr
|
||||
: (int32_t*)(void*)(weight +
|
||||
(nc * Kc + kci) *
|
||||
(BLOCK_N *
|
||||
(_block_k / 2 + sizeof(int32_t))) +
|
||||
_block_k * BLOCK_N / 2);
|
||||
_dequant_gemm_accum<sym_quant_act, BLOCK_N, BLOCK_N / 2>(
|
||||
/*C*/ C_tmp,
|
||||
/*A*/ (uint8_t*)input + mci * block_m * K + kci * _block_k,
|
||||
/*scales_a*/ input_scales + mci * block_m,
|
||||
/*qzeros_a*/ input_qzeros + mci * block_m,
|
||||
/*B*/ weight + (nc * Kc + kci) *
|
||||
(BLOCK_N * (_block_k / 2 + sizeof(int32_t))),
|
||||
/*scales_b*/ weight_scales + nc * BLOCK_N * num_groups +
|
||||
kci / block_per_group * BLOCK_N,
|
||||
/*qzeros_b*/ weight_qzeros + nc * BLOCK_N * num_groups +
|
||||
kci / block_per_group * BLOCK_N,
|
||||
/*Bcomp*/ compensation_ptr,
|
||||
/*dqB_tmp*/ dqB_tmp,
|
||||
/*M*/ m_size,
|
||||
/*K*/ _block_k,
|
||||
/*lda*/ K,
|
||||
/*ldc*/ BLOCK_N,
|
||||
/*use_brgemm*/ use_brgemm);
|
||||
}
|
||||
store_out<BLOCK_N>(C_tmp, output + mci * block_m * N + nc * BLOCK_N,
|
||||
m_size, N /*lda*/);
|
||||
}
|
||||
}
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor, at::Tensor>
|
||||
convert_int4_weight_packed_with_compensation(const at::Tensor& weight,
|
||||
const at::Tensor& scales,
|
||||
const at::Tensor& qzeros) {
|
||||
TORCH_CHECK(weight.dim() == 2,
|
||||
"DA8W4 CPU: Weight should be a 2D tensor for packing");
|
||||
TORCH_CHECK(
|
||||
weight.size(1) % 2 == 0,
|
||||
"DA8W4 CPU: Weight should have even number of columns for packing");
|
||||
|
||||
auto new_scales = scales;
|
||||
auto new_qzeros = qzeros;
|
||||
if (new_scales.dim() == 1) {
|
||||
new_scales.unsqueeze_(1);
|
||||
}
|
||||
new_scales = new_scales.to(at::kFloat);
|
||||
if (new_qzeros.dim() == 1) {
|
||||
new_qzeros.unsqueeze_(1);
|
||||
}
|
||||
new_qzeros = new_qzeros.to(at::kChar);
|
||||
int64_t N = weight.size(0);
|
||||
int64_t K = weight.size(1);
|
||||
int64_t G = scales.size(1);
|
||||
int64_t group_size = K / G;
|
||||
int64_t _block_k = get_4bit_block_k_size(group_size);
|
||||
constexpr int block_n = block_size_n();
|
||||
int64_t Nc = N / block_n;
|
||||
int64_t Kc = K / _block_k;
|
||||
|
||||
auto weight_view = weight.view({Nc, block_n, Kc, _block_k});
|
||||
at::Tensor weight_reordered = weight_view.permute({0, 2, 3, 1}).contiguous();
|
||||
at::Tensor blocked_weight;
|
||||
at::Tensor blocked_scales =
|
||||
new_scales.view({Nc, block_n, G}).permute({0, 2, 1}).contiguous();
|
||||
at::Tensor blocked_qzeros =
|
||||
new_qzeros.view({Nc, block_n, G}).permute({0, 2, 1}).contiguous();
|
||||
auto weight_sub_qzero = weight.view({Nc, block_n, G, -1}).to(at::kInt) -
|
||||
new_qzeros.view({Nc, block_n, G, -1});
|
||||
weight_sub_qzero = weight_sub_qzero.view({Nc, block_n, Kc, _block_k});
|
||||
at::Tensor compensation = weight_sub_qzero.sum(-1);
|
||||
compensation = compensation.permute({0, 2, 1}).contiguous().to(at::kInt);
|
||||
int64_t buffer_size_nbytes =
|
||||
_block_k * block_n / 2 + block_n * sizeof(int32_t);
|
||||
blocked_weight = at::empty({Nc, Kc, buffer_size_nbytes}, weight.options());
|
||||
|
||||
auto weight_ptr = weight_reordered.data_ptr<uint8_t>();
|
||||
auto compensation_ptr = compensation.data_ptr<int32_t>();
|
||||
auto blocked_weight_ptr = blocked_weight.data_ptr<uint8_t>();
|
||||
int64_t num_blocks = Nc * Kc;
|
||||
at::parallel_for(0, num_blocks, 1, [&](int64_t begin, int64_t end) {
|
||||
for (const auto i : c10::irange(begin, end)) {
|
||||
auto in_ptr = weight_ptr + i * _block_k * block_n;
|
||||
auto out_ptr =
|
||||
blocked_weight_ptr + i * block_n * (_block_k / 2 + sizeof(int32_t));
|
||||
int32_t* comp_in_prt = compensation_ptr + i * block_n;
|
||||
int32_t* comp_out_prt =
|
||||
(int32_t*)(void*)(blocked_weight_ptr +
|
||||
i * block_n * (_block_k / 2 + sizeof(int32_t)) +
|
||||
_block_k * block_n / 2);
|
||||
constexpr int n_group_size = 8;
|
||||
constexpr int vnni_size = 4;
|
||||
constexpr int n_group = block_n / n_group_size;
|
||||
for (int nb = 0; nb < n_group; nb += 2) {
|
||||
for (int k = 0; k < _block_k; k += vnni_size) {
|
||||
for (int ni = 0; ni < n_group_size; ++ni) {
|
||||
for (int ki = 0; ki < vnni_size; ++ki) {
|
||||
int src_idx_1 = nb * n_group_size + ni + (k + ki) * block_n;
|
||||
int src_idx_2 = (nb + 1) * n_group_size + ni + (k + ki) * block_n;
|
||||
int dst_idx = (nb / 2 * n_group_size + ni) * vnni_size +
|
||||
k * block_n / 2 + ki;
|
||||
uint8_t src_1 = *(in_ptr + src_idx_1);
|
||||
uint8_t src_2 = *(in_ptr + src_idx_2);
|
||||
uint8_t dst = (src_1 & 0x0f) | ((src_2 & 0x0f) << 4);
|
||||
*(out_ptr + dst_idx) = dst;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int nb = 0; nb < block_n; nb++) {
|
||||
*(comp_out_prt + nb) = *(comp_in_prt + nb);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
return std::make_tuple(std::move(blocked_weight), std::move(blocked_scales),
|
||||
std::move(blocked_qzeros));
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> autoawq_to_int4pack(at::Tensor qweight,
|
||||
at::Tensor qzeros) {
|
||||
auto bitshifts = at::tensor({0, 4, 1, 5, 2, 6, 3, 7}, at::kInt) * 4;
|
||||
auto qweight_unsq = qweight.unsqueeze(-1);
|
||||
auto unpacked = at::bitwise_right_shift(qweight_unsq, bitshifts) & 0xF;
|
||||
auto qweight_final = unpacked.flatten(-2).transpose(-1, -2).to(at::kByte);
|
||||
|
||||
auto qzeros_unsq = qzeros.unsqueeze(-1);
|
||||
auto qzeros_unpacked = at::bitwise_right_shift(qzeros_unsq, bitshifts) & 0xF;
|
||||
auto qzeros_final = qzeros_unpacked.flatten(-2).to(at::kByte);
|
||||
|
||||
return std::make_tuple(qweight_final, qzeros_final);
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor, at::Tensor> convert_weight_packed_scale_zp(
|
||||
at::Tensor qweight, at::Tensor qzeros, at::Tensor scales) {
|
||||
auto res = autoawq_to_int4pack(qweight, qzeros);
|
||||
auto _qweight = std::get<0>(res);
|
||||
auto _qzeros = std::get<1>(res);
|
||||
auto _scales = scales;
|
||||
_qzeros = _qzeros.transpose(-2, -1).contiguous();
|
||||
_scales = _scales.transpose(-2, -1).contiguous();
|
||||
if (_qweight.dim() == 3) {
|
||||
int64_t E = _qweight.size(0);
|
||||
int64_t K = _qweight.size(2);
|
||||
int64_t G = _scales.size(2);
|
||||
int64_t group_size = K / G;
|
||||
int64_t _block_k = get_4bit_block_k_size(group_size);
|
||||
int64_t block_n = block_size_n();
|
||||
int64_t Nc = _qweight.size(1) / block_n;
|
||||
int64_t Kc = K / _block_k;
|
||||
int64_t buffer_size_nbytes =
|
||||
_block_k * block_n / 2 + block_n * sizeof(int32_t);
|
||||
auto blocked_weight =
|
||||
at::empty({E, Nc, Kc, buffer_size_nbytes}, _qweight.options());
|
||||
auto blocked_scales =
|
||||
at::empty({E, Nc, G, block_n}, _scales.options()).to(at::kFloat);
|
||||
auto blocked_qzeros =
|
||||
at::empty({E, Nc, G, block_n}, _qzeros.options()).to(at::kChar);
|
||||
for (int i = 0; i < _qweight.size(0); i++) {
|
||||
auto res_ = convert_int4_weight_packed_with_compensation(
|
||||
_qweight[i], _scales[i], _qzeros[i]);
|
||||
blocked_weight[i] = std::get<0>(res_);
|
||||
blocked_scales[i] = std::get<1>(res_);
|
||||
blocked_qzeros[i] = std::get<2>(res_);
|
||||
}
|
||||
_qweight = blocked_weight;
|
||||
_scales = blocked_scales;
|
||||
_qzeros = blocked_qzeros;
|
||||
} else {
|
||||
auto res_ = convert_int4_weight_packed_with_compensation(_qweight, _scales,
|
||||
_qzeros);
|
||||
_qweight = std::get<0>(res_);
|
||||
_scales = std::get<1>(res_);
|
||||
_qzeros = std::get<2>(res_);
|
||||
}
|
||||
|
||||
return std::make_tuple(_qweight, _qzeros, _scales);
|
||||
}
|
||||
|
||||
at::Tensor int4_scaled_mm_cpu_with_quant(const at::Tensor& input,
|
||||
const at::Tensor& weight,
|
||||
const at::Tensor& weight_scales,
|
||||
const at::Tensor& weight_qzeros,
|
||||
const std::optional<at::Tensor>& bias,
|
||||
at::ScalarType output_dtype) {
|
||||
RECORD_FUNCTION("vllm::int4_scaled_mm_cpu_with_quant",
|
||||
std::vector<c10::IValue>({input, weight}));
|
||||
|
||||
int64_t M_a = input.size(0);
|
||||
int64_t K_a = input.size(1);
|
||||
int64_t lda = input.stride(0);
|
||||
|
||||
const auto st = input.scalar_type();
|
||||
TORCH_CHECK(
|
||||
st == at::kBFloat16 || st == at::kHalf,
|
||||
"int4_scaled_mm_cpu_with_quant: expect A to be bfloat16 or half.");
|
||||
|
||||
constexpr bool sym_quant_act = false;
|
||||
using Tin = typename ActDtype<sym_quant_act>::type;
|
||||
int64_t act_buffer_size =
|
||||
M_a * K_a + M_a * sizeof(float) + M_a * sizeof(int32_t);
|
||||
auto act_buffer =
|
||||
at::empty({act_buffer_size}, input.options().dtype(at::kByte));
|
||||
auto Aq_data = act_buffer.data_ptr<uint8_t>();
|
||||
auto As_data = reinterpret_cast<float*>(Aq_data + M_a * K_a);
|
||||
auto Azp_data = reinterpret_cast<int32_t*>(As_data + M_a);
|
||||
fill_val_stub(Azp_data, 128, M_a);
|
||||
|
||||
auto out_sizes = input.sizes().vec();
|
||||
int64_t N = weight_scales.size(0) * weight_scales.size(-1);
|
||||
out_sizes.back() = N;
|
||||
auto output = at::empty(out_sizes, input.options());
|
||||
int64_t Nc = weight.size(0);
|
||||
int64_t Kc = weight.size(1);
|
||||
int64_t _block_k = K_a / Kc;
|
||||
TORCH_CHECK(N == Nc * BLOCK_N, "DA8W4: weight and input shapes mismatch");
|
||||
int64_t num_groups = weight_scales.size(1);
|
||||
|
||||
const uint8_t* b_ptr = weight.data_ptr<uint8_t>();
|
||||
const float* b_scales_ptr = weight_scales.data_ptr<float>();
|
||||
const int8_t* b_qzeros_ptr = weight_qzeros.data_ptr<int8_t>();
|
||||
const float* bias_ptr =
|
||||
bias.has_value() ? bias.value().data_ptr<float>() : nullptr;
|
||||
int num_threads = at::get_num_threads();
|
||||
int64_t temp_buffer_size = num_threads * BLOCK_M * BLOCK_N * sizeof(float) +
|
||||
num_threads * _block_k * BLOCK_N;
|
||||
auto c_temp_buffer =
|
||||
at::empty({temp_buffer_size}, input.options().dtype(at::kChar));
|
||||
float* c_temp_ptr = (float*)((void*)(c_temp_buffer.data_ptr<int8_t>()));
|
||||
int8_t* dqB_temp_ptr =
|
||||
(int8_t*)((void*)(c_temp_ptr + num_threads * BLOCK_M * BLOCK_N));
|
||||
|
||||
#define LAUNCH_DA8W4_LINEAR_WITH_QUANT_IMPL(sym_quant_act) \
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2( \
|
||||
at::ScalarType::BFloat16, at::ScalarType::Half, output_dtype, \
|
||||
"int4_scaled_mm_cpu", [&] { \
|
||||
const scalar_t* __restrict__ A_data = input.data_ptr<scalar_t>(); \
|
||||
scalar_t* __restrict__ c_ptr = output.data_ptr<scalar_t>(); \
|
||||
at::parallel_for(0, M_a, 0, [&](int64_t begin, int64_t end) { \
|
||||
for (int64_t m = begin; m < end; ++m) { \
|
||||
quantize_row_int8<scalar_t>(Aq_data + m * K_a, As_data[m], \
|
||||
A_data + m * lda, K_a); \
|
||||
} \
|
||||
}); \
|
||||
_da8w4_linear_impl<sym_quant_act, Tin, scalar_t>( \
|
||||
Aq_data, As_data, Azp_data, b_ptr, b_scales_ptr, b_qzeros_ptr, \
|
||||
bias_ptr, c_ptr, c_temp_ptr, dqB_temp_ptr, M_a, N, K_a, \
|
||||
num_groups); \
|
||||
});
|
||||
|
||||
LAUNCH_DA8W4_LINEAR_WITH_QUANT_IMPL(sym_quant_act);
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void copy_stub(scalar_t* __restrict__ out,
|
||||
const float* __restrict__ input, int64_t size) {
|
||||
using Vec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
#pragma GCC unroll 4
|
||||
for (int64_t d = 0; d < size; d += Vec::size()) {
|
||||
fVec x0 = fVec::loadu(input + d);
|
||||
fVec x1 = fVec::loadu(input + d + fVec::size());
|
||||
Vec res = convert_from_float_ext<scalar_t>(x0, x1);
|
||||
res.store(out + d);
|
||||
}
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(scalar_t* C, float* C_temp, const uint8_t* A,
|
||||
const float* scales_a, const int32_t* qzeros_a,
|
||||
const uint8_t* B, const float* scales_b,
|
||||
const int8_t* qzeros_b, const int32_t* compensation,
|
||||
int8_t* dqB_tmp, int64_t M, int64_t K, int64_t lda,
|
||||
int64_t ldc_f, int64_t ldc_s, bool store_out,
|
||||
bool use_brgemm) {
|
||||
_dequant_gemm_accum<false, BLOCK_N, BLOCK_N / 2>(
|
||||
C_temp, A, scales_a, qzeros_a, B, scales_b, qzeros_b, compensation,
|
||||
dqB_tmp, M, K, lda, ldc_f, use_brgemm);
|
||||
if (store_out) {
|
||||
for (int64_t m = 0; m < M; ++m) {
|
||||
copy_stub<scalar_t>(C + m * ldc_s, C_temp + m * ldc_f, BLOCK_N);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \
|
||||
template void tinygemm_kernel<TYPE>( \
|
||||
TYPE * C, float* C_temp, const uint8_t* A, const float* scales_a, \
|
||||
const int32_t* qzeros_a, const uint8_t* B, const float* scales_b, \
|
||||
const int8_t* qzeros_b, const int32_t* compensation, int8_t* dqB_tmp, \
|
||||
int64_t M, int64_t K, int64_t lda, int64_t ldc_f, int64_t ldc_s, \
|
||||
bool store_out, bool use_brgemm)
|
||||
|
||||
INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16);
|
||||
INSTANTIATE_TINYGEMM_TEMPLATE(at::Half);
|
||||
|
||||
at::Tensor int4_scaled_mm_cpu(at::Tensor& x, at::Tensor& w, at::Tensor& w_zeros,
|
||||
at::Tensor& w_scales,
|
||||
std::optional<at::Tensor> bias) {
|
||||
return int4_scaled_mm_cpu_with_quant(x, w, w_scales, w_zeros, bias,
|
||||
x.scalar_type());
|
||||
}
|
||||
@@ -79,14 +79,6 @@ at::Tensor int8_scaled_mm_with_quant(at::Tensor& mat1, at::Tensor& mat2,
|
||||
const std::optional<at::Tensor>& bias,
|
||||
at::ScalarType out_dtype, bool is_vnni);
|
||||
|
||||
// Adapted from sglang: INT4 W4A8 kernels
|
||||
std::tuple<at::Tensor, at::Tensor, at::Tensor> convert_weight_packed_scale_zp(
|
||||
at::Tensor qweight, at::Tensor qzeros, at::Tensor scales);
|
||||
|
||||
at::Tensor int4_scaled_mm_cpu(at::Tensor& x, at::Tensor& w, at::Tensor& w_zeros,
|
||||
at::Tensor& w_scales,
|
||||
std::optional<at::Tensor> bias);
|
||||
|
||||
torch::Tensor get_scheduler_metadata(
|
||||
const int64_t num_req, const int64_t num_heads_q,
|
||||
const int64_t num_heads_kv, const int64_t head_dim,
|
||||
@@ -293,18 +285,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"Tensor? bias, ScalarType out_dtype, bool is_vnni) -> Tensor");
|
||||
ops.impl("int8_scaled_mm_with_quant", torch::kCPU,
|
||||
&int8_scaled_mm_with_quant);
|
||||
|
||||
// Adapted from sglang: INT4 W4A8 kernels
|
||||
ops.def(
|
||||
"convert_weight_packed_scale_zp(Tensor qweight, Tensor qzeros, "
|
||||
"Tensor scales) -> (Tensor, Tensor, Tensor)");
|
||||
ops.impl("convert_weight_packed_scale_zp", torch::kCPU,
|
||||
&convert_weight_packed_scale_zp);
|
||||
|
||||
ops.def(
|
||||
"int4_scaled_mm_cpu(Tensor(a0!) x, Tensor(a1!) w, Tensor(a2!) w_zeros, "
|
||||
"Tensor(a3!) w_scales, Tensor? bias) -> Tensor");
|
||||
ops.impl("int4_scaled_mm_cpu", torch::kCPU, &int4_scaled_mm_cpu);
|
||||
#endif
|
||||
|
||||
// CPU attention kernels
|
||||
|
||||
@@ -6,16 +6,14 @@
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
|
||||
#include <torch/headeronly/util/shim_utils.h>
|
||||
|
||||
/**
|
||||
* Helper function for checking CUTLASS errors
|
||||
*/
|
||||
#define CUTLASS_CHECK(status) \
|
||||
{ \
|
||||
cutlass::Status error = status; \
|
||||
STD_TORCH_CHECK(error == cutlass::Status::kSuccess, \
|
||||
cutlassGetStatusString(error)); \
|
||||
#define CUTLASS_CHECK(status) \
|
||||
{ \
|
||||
cutlass::Status error = status; \
|
||||
TORCH_CHECK(error == cutlass::Status::kSuccess, \
|
||||
cutlassGetStatusString(error)); \
|
||||
}
|
||||
|
||||
inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
|
||||
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp"
|
||||
|
||||
/*
|
||||
@@ -54,7 +52,7 @@ struct ScaledEpilogueBase {
|
||||
// from a tensor. It can handle both row and column, as well as row/column or
|
||||
// scalar cases.
|
||||
template <typename Descriptor, typename T>
|
||||
static auto args_from_tensor(torch::stable::Tensor const& tensor) {
|
||||
static auto args_from_tensor(torch::Tensor const& tensor) {
|
||||
using Arguments = typename Descriptor::Arguments;
|
||||
auto* data_ptr = static_cast<T*>(tensor.data_ptr());
|
||||
if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
|
||||
@@ -70,8 +68,7 @@ struct ScaledEpilogueBase {
|
||||
// This overload handles the case where there might not be a tensor, in which
|
||||
// case a nullptr is passed and a constant (0) is used.
|
||||
template <typename Descriptor, typename T>
|
||||
static auto args_from_tensor(
|
||||
std::optional<torch::stable::Tensor> const& tensor) {
|
||||
static auto args_from_tensor(std::optional<torch::Tensor> const& tensor) {
|
||||
static_assert(std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
|
||||
using Arguments = typename Descriptor::Arguments;
|
||||
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
|
||||
@@ -120,8 +117,8 @@ struct ScaledEpilogue
|
||||
cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType prepare_args(torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales) {
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
|
||||
@@ -163,9 +160,9 @@ struct ScaledEpilogueBias
|
||||
using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
|
||||
EVTCompute0, Bias>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
static ArgumentType prepare_args(torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
torch::stable::Tensor const& bias) {
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
@@ -223,11 +220,10 @@ struct ScaledEpilogueBiasAzp
|
||||
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType prepare_args(
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
torch::stable::Tensor const& azp_adj,
|
||||
std::optional<torch::stable::Tensor> const& bias) {
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
@@ -302,11 +298,11 @@ struct ScaledEpilogueBiasAzpToken
|
||||
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType prepare_args(
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
torch::stable::Tensor const& azp_adj, torch::stable::Tensor const& azp,
|
||||
std::optional<torch::stable::Tensor> const& bias) {
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
torch::Tensor const& azp,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
@@ -3,14 +3,6 @@
|
||||
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
|
||||
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp"
|
||||
|
||||
// This header is shared by both _C (unstable ABI) and _C_stable_libtorch
|
||||
// (stable ABI) targets. When compiled under the stable ABI target,
|
||||
// TORCH_TARGET_VERSION is defined and Tensor is unavailable, so we
|
||||
// use torch::stable::Tensor instead.
|
||||
#ifdef TORCH_TARGET_VERSION
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#endif
|
||||
|
||||
/*
|
||||
This file defines custom epilogues for fusing channel scales, token scales,
|
||||
bias, and activation zero-points onto a GEMM operation using the
|
||||
@@ -23,12 +15,6 @@
|
||||
|
||||
namespace vllm::c3x {
|
||||
|
||||
#ifdef TORCH_TARGET_VERSION
|
||||
using TensorType = torch::stable::Tensor;
|
||||
#else
|
||||
using TensorType = torch::Tensor;
|
||||
#endif
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <typename T>
|
||||
@@ -98,7 +84,7 @@ struct ScaledEpilogueBase {
|
||||
// from a tensor. It can handle both row and column, as well as row/column or
|
||||
// scalar cases.
|
||||
template <typename Descriptor, typename T>
|
||||
static auto args_from_tensor(TensorType const& tensor) {
|
||||
static auto args_from_tensor(torch::Tensor const& tensor) {
|
||||
using Arguments = typename Descriptor::Arguments;
|
||||
auto* data_ptr = static_cast<T*>(tensor.data_ptr());
|
||||
if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
|
||||
@@ -114,7 +100,7 @@ struct ScaledEpilogueBase {
|
||||
// This overload handles the case where there might not be a tensor, in which
|
||||
// case a nullptr is passed and a constant (0) is used.
|
||||
template <typename Descriptor, typename T>
|
||||
static auto args_from_tensor(std::optional<TensorType> const& tensor) {
|
||||
static auto args_from_tensor(std::optional<torch::Tensor> const& tensor) {
|
||||
using Arguments = typename Descriptor::Arguments;
|
||||
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
|
||||
static_assert(std::is_same_v<Descriptor, ColLoad<T, true>> ||
|
||||
@@ -172,8 +158,8 @@ struct ScaledEpilogue
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType prepare_args(TensorType const& a_scales,
|
||||
TensorType const& b_scales) {
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
|
||||
@@ -217,9 +203,9 @@ struct ScaledEpilogueBias
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
|
||||
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
static ArgumentType prepare_args(TensorType const& a_scales,
|
||||
TensorType const& b_scales,
|
||||
TensorType const& bias) {
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
@@ -260,9 +246,9 @@ struct ScaledEpilogueColumnBias
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
|
||||
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
static ArgumentType prepare_args(TensorType const& a_scales,
|
||||
TensorType const& b_scales,
|
||||
TensorType const& bias) {
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
@@ -318,10 +304,10 @@ struct ScaledEpilogueBiasAzp
|
||||
EVTComputeScaleB, Bias>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType prepare_args(TensorType const& a_scales,
|
||||
TensorType const& b_scales,
|
||||
TensorType const& azp_adj,
|
||||
std::optional<TensorType> const& bias) {
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
@@ -394,11 +380,11 @@ struct ScaledEpilogueBiasAzpToken
|
||||
EVTComputeScaleB, Bias>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType prepare_args(TensorType const& a_scales,
|
||||
TensorType const& b_scales,
|
||||
TensorType const& azp_adj,
|
||||
TensorType const& azp,
|
||||
std::optional<TensorType> const& bias) {
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
torch::Tensor const& azp,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
|
||||
@@ -27,61 +27,4 @@ void per_token_group_quant_int8(const torch::stable::Tensor& input,
|
||||
torch::stable::Tensor& output_s,
|
||||
int64_t group_size, double eps, double int8_min,
|
||||
double int8_max);
|
||||
|
||||
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
|
||||
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability);
|
||||
bool cutlass_group_gemm_supported(int64_t cuda_device_capability);
|
||||
|
||||
void cutlass_scaled_mm(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
std::optional<torch::stable::Tensor> const& bias);
|
||||
|
||||
void cutlass_moe_mm(torch::stable::Tensor& out_tensors,
|
||||
torch::stable::Tensor const& a_tensors,
|
||||
torch::stable::Tensor const& b_tensors,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
torch::stable::Tensor const& expert_offsets,
|
||||
torch::stable::Tensor const& problem_sizes,
|
||||
torch::stable::Tensor const& a_strides,
|
||||
torch::stable::Tensor const& b_strides,
|
||||
torch::stable::Tensor const& c_strides, bool per_act_token,
|
||||
bool per_out_ch);
|
||||
|
||||
void cutlass_scaled_mm_azp(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
torch::stable::Tensor const& azp_adj,
|
||||
std::optional<torch::stable::Tensor> const& azp,
|
||||
std::optional<torch::stable::Tensor> const& bias);
|
||||
|
||||
void get_cutlass_moe_mm_data(
|
||||
const torch::stable::Tensor& topk_ids,
|
||||
torch::stable::Tensor& expert_offsets,
|
||||
torch::stable::Tensor& problem_sizes1,
|
||||
torch::stable::Tensor& problem_sizes2,
|
||||
torch::stable::Tensor& input_permutation,
|
||||
torch::stable::Tensor& output_permutation, const int64_t num_experts,
|
||||
const int64_t n, const int64_t k,
|
||||
const std::optional<torch::stable::Tensor>& blockscale_offsets,
|
||||
const bool is_gated);
|
||||
|
||||
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
|
||||
const torch::stable::Tensor& expert_first_token_offset,
|
||||
torch::stable::Tensor& problem_sizes1,
|
||||
torch::stable::Tensor& problem_sizes2, const int64_t n, const int64_t k,
|
||||
const bool swap_ab);
|
||||
|
||||
void get_cutlass_batched_moe_mm_data(
|
||||
torch::stable::Tensor& expert_offsets,
|
||||
torch::stable::Tensor& problem_sizes1,
|
||||
torch::stable::Tensor& problem_sizes2,
|
||||
const torch::stable::Tensor& expert_num_tokens,
|
||||
const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
|
||||
const int64_t k);
|
||||
#endif
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
#include "scaled_mm_kernels.hpp"
|
||||
#include "scaled_mm_blockwise_sm100_fp8_dispatch.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
void cutlass_scaled_mm_blockwise_sm100_fp8(
|
||||
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales) {
|
||||
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
|
||||
cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::bfloat16_t>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
|
||||
} else {
|
||||
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
|
||||
cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::half_t>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@@ -1,22 +0,0 @@
|
||||
#include "scaled_mm_kernels.hpp"
|
||||
#include "scaled_mm_blockwise_sm120_fp8_dispatch.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
void cutlass_scaled_mm_blockwise_sm120_fp8(
|
||||
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales) {
|
||||
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
|
||||
cutlass_gemm_blockwise_sm120_fp8_dispatch<cutlass::bfloat16_t>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
|
||||
} else {
|
||||
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
|
||||
cutlass_gemm_blockwise_sm120_fp8_dispatch<cutlass::half_t>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@@ -1,23 +0,0 @@
|
||||
|
||||
#include "scaled_mm_kernels.hpp"
|
||||
#include "scaled_mm_blockwise_sm90_fp8_dispatch.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
void cutlass_scaled_mm_blockwise_sm90_fp8(
|
||||
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales) {
|
||||
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
|
||||
cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::bfloat16_t>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
|
||||
} else {
|
||||
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
|
||||
cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::half_t>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@@ -1,52 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
|
||||
namespace vllm {
|
||||
|
||||
void cutlass_scaled_mm_sm90_fp8(
|
||||
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
std::optional<torch::stable::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_sm90_int8(
|
||||
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
std::optional<torch::stable::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_azp_sm90_int8(
|
||||
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
|
||||
std::optional<torch::stable::Tensor> const& azp,
|
||||
std::optional<torch::stable::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_blockwise_sm90_fp8(
|
||||
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales);
|
||||
|
||||
void cutlass_scaled_mm_sm100_fp8(
|
||||
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
std::optional<torch::stable::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_sm120_fp8(
|
||||
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
std::optional<torch::stable::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_blockwise_sm100_fp8(
|
||||
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales);
|
||||
|
||||
void cutlass_scaled_mm_blockwise_sm120_fp8(
|
||||
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales);
|
||||
} // namespace vllm
|
||||
@@ -1,24 +0,0 @@
|
||||
#include "scaled_mm_kernels.hpp"
|
||||
#include "scaled_mm_sm100_fp8_dispatch.cuh"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
void cutlass_scaled_mm_sm100_fp8(
|
||||
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
std::optional<torch::stable::Tensor> const& bias) {
|
||||
STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
if (bias) {
|
||||
STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
|
||||
"currently bias dtype must match output dtype ",
|
||||
out.scalar_type());
|
||||
return cutlass_scaled_mm_sm100_fp8_epilogue<true>(out, a, b, a_scales,
|
||||
b_scales, *bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm100_fp8_epilogue<false>(out, a, b, a_scales,
|
||||
b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@@ -1,25 +0,0 @@
|
||||
#include "scaled_mm_kernels.hpp"
|
||||
#include "scaled_mm_sm120_fp8_dispatch.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
void cutlass_scaled_mm_sm120_fp8(
|
||||
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
std::optional<torch::stable::Tensor> const& bias) {
|
||||
STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
if (bias) {
|
||||
STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
|
||||
"currently bias dtype must match output dtype ",
|
||||
out.scalar_type());
|
||||
return cutlass_scaled_mm_sm120_fp8_epilogue<c3x::ScaledEpilogueBias>(
|
||||
out, a, b, a_scales, b_scales, *bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm120_fp8_epilogue<c3x::ScaledEpilogue>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@@ -1,24 +0,0 @@
|
||||
#include "scaled_mm_kernels.hpp"
|
||||
#include "scaled_mm_sm90_fp8_dispatch.cuh"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
void cutlass_scaled_mm_sm90_fp8(
|
||||
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
std::optional<torch::stable::Tensor> const& bias) {
|
||||
STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
if (bias) {
|
||||
STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
|
||||
"currently bias dtype must match output dtype ",
|
||||
out.scalar_type());
|
||||
return cutlass_scaled_mm_sm90_fp8_epilogue<true>(out, a, b, a_scales,
|
||||
b_scales, *bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm90_fp8_epilogue<false>(out, a, b, a_scales,
|
||||
b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@@ -1,25 +0,0 @@
|
||||
#include "scaled_mm_kernels.hpp"
|
||||
#include "scaled_mm_sm90_int8_dispatch.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
void cutlass_scaled_mm_sm90_int8(
|
||||
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
std::optional<torch::stable::Tensor> const& bias) {
|
||||
STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
if (bias) {
|
||||
STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
|
||||
"currently bias dtype must match output dtype ",
|
||||
out.scalar_type());
|
||||
return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogueBias>(
|
||||
out, a, b, a_scales, b_scales, *bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogue>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@@ -1,220 +0,0 @@
|
||||
#include <stddef.h>
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include <torch/headeronly/core/ScalarType.h>
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "scaled_mm_c2x.cuh"
|
||||
#include "scaled_mm_c2x_sm75_dispatch.cuh"
|
||||
#include "scaled_mm_c2x_sm80_dispatch.cuh"
|
||||
#include "scaled_mm_c2x_sm89_fp8_dispatch.cuh"
|
||||
#include "scaled_mm_c2x_sm89_int8_dispatch.cuh"
|
||||
|
||||
#include "libtorch_stable/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp"
|
||||
|
||||
using namespace vllm;
|
||||
|
||||
/*
|
||||
This file defines quantized GEMM operations using the CUTLASS 2.x API, for
|
||||
NVIDIA GPUs with SM versions prior to sm90 (Hopper).
|
||||
*/
|
||||
|
||||
template <template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_scaled_mm_sm75_epilogue(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_args) {
|
||||
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||
|
||||
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
|
||||
return cutlass_gemm_sm75_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
|
||||
return cutlass_gemm_sm75_dispatch<int8_t, cutlass::half_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm_sm75(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
std::optional<torch::stable::Tensor> const& bias) {
|
||||
STD_TORCH_CHECK(a_scales.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float);
|
||||
STD_TORCH_CHECK(b_scales.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float);
|
||||
if (bias) {
|
||||
STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
|
||||
"currently bias dtype must match output dtype ",
|
||||
out.scalar_type());
|
||||
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBias>(
|
||||
out, a, b, a_scales, b_scales, *bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogue>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm_azp_sm75(
|
||||
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
|
||||
std::optional<torch::stable::Tensor> const& azp,
|
||||
std::optional<torch::stable::Tensor> const& bias) {
|
||||
STD_TORCH_CHECK(a_scales.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float);
|
||||
STD_TORCH_CHECK(b_scales.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float);
|
||||
|
||||
if (azp) {
|
||||
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
|
||||
out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBiasAzp>(
|
||||
out, a, b, a_scales, b_scales, azp_adj, bias);
|
||||
}
|
||||
}
|
||||
|
||||
template <template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_scaled_mm_sm80_epilogue(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_args) {
|
||||
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||
|
||||
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
|
||||
return cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
|
||||
return cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm_sm80(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
std::optional<torch::stable::Tensor> const& bias) {
|
||||
STD_TORCH_CHECK(a_scales.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float);
|
||||
STD_TORCH_CHECK(b_scales.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float);
|
||||
if (bias) {
|
||||
STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
|
||||
"currently bias dtype must match output dtype ",
|
||||
out.scalar_type());
|
||||
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBias>(
|
||||
out, a, b, a_scales, b_scales, *bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogue>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm_azp_sm80(
|
||||
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
|
||||
std::optional<torch::stable::Tensor> const& azp,
|
||||
std::optional<torch::stable::Tensor> const& bias) {
|
||||
STD_TORCH_CHECK(a_scales.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float);
|
||||
STD_TORCH_CHECK(b_scales.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float);
|
||||
|
||||
if (azp) {
|
||||
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
|
||||
out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBiasAzp>(
|
||||
out, a, b, a_scales, b_scales, azp_adj, bias);
|
||||
}
|
||||
}
|
||||
|
||||
template <template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_scaled_mm_sm89_epilogue(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_args) {
|
||||
if (a.scalar_type() == torch::headeronly::ScalarType::Char) {
|
||||
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||
|
||||
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
|
||||
return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::bfloat16_t,
|
||||
Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
assert(out.scalar_type() == torch::headeronly::ScalarType::Half);
|
||||
return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
} else {
|
||||
STD_TORCH_CHECK(a.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
STD_TORCH_CHECK(b.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
|
||||
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
|
||||
return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::bfloat16_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
|
||||
return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::half_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm_sm89(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
std::optional<torch::stable::Tensor> const& bias) {
|
||||
STD_TORCH_CHECK(a_scales.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float);
|
||||
STD_TORCH_CHECK(b_scales.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float);
|
||||
if (bias) {
|
||||
STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
|
||||
"currently bias dtype must match output dtype ",
|
||||
out.scalar_type());
|
||||
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBias>(
|
||||
out, a, b, a_scales, b_scales, *bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogue>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm_azp_sm89(
|
||||
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
|
||||
std::optional<torch::stable::Tensor> const& azp,
|
||||
std::optional<torch::stable::Tensor> const& bias) {
|
||||
STD_TORCH_CHECK(a_scales.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float);
|
||||
STD_TORCH_CHECK(b_scales.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float);
|
||||
|
||||
if (azp) {
|
||||
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
|
||||
out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBiasAzp>(
|
||||
out, a, b, a_scales, b_scales, azp_adj, bias);
|
||||
}
|
||||
}
|
||||
@@ -1,38 +0,0 @@
|
||||
#include "c3x/scaled_mm_helper.hpp"
|
||||
#include "c3x/scaled_mm_kernels.hpp"
|
||||
|
||||
/*
|
||||
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
|
||||
NVIDIA GPUs with sm90a (Hopper).
|
||||
*/
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
|
||||
|
||||
void cutlass_scaled_mm_sm90(torch::stable::Tensor& c,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
std::optional<torch::stable::Tensor> const& bias) {
|
||||
dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
|
||||
vllm::cutlass_scaled_mm_sm90_fp8,
|
||||
vllm::cutlass_scaled_mm_sm90_int8,
|
||||
vllm::cutlass_scaled_mm_blockwise_sm90_fp8);
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm_azp_sm90(
|
||||
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
|
||||
std::optional<torch::stable::Tensor> const& azp,
|
||||
std::optional<torch::stable::Tensor> const& bias) {
|
||||
STD_TORCH_CHECK(a_scales.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float);
|
||||
STD_TORCH_CHECK(b_scales.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float);
|
||||
|
||||
vllm::cutlass_scaled_mm_azp_sm90_int8(out, a, b, a_scales, b_scales, azp_adj,
|
||||
azp, bias);
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -1,451 +0,0 @@
|
||||
#include <cudaTypedefs.h>
|
||||
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
|
||||
void cutlass_scaled_mm_sm75(torch::stable::Tensor& c,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
std::optional<torch::stable::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_sm80(torch::stable::Tensor& c,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
std::optional<torch::stable::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_sm89(torch::stable::Tensor& c,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
std::optional<torch::stable::Tensor> const& bias);
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
|
||||
void cutlass_scaled_mm_sm90(torch::stable::Tensor& c,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
std::optional<torch::stable::Tensor> const& bias);
|
||||
#endif
|
||||
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
|
||||
void cutlass_moe_mm_sm90(torch::stable::Tensor& out_tensors,
|
||||
torch::stable::Tensor const& a_tensors,
|
||||
torch::stable::Tensor const& b_tensors,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
torch::stable::Tensor const& expert_offsets,
|
||||
torch::stable::Tensor const& problem_sizes,
|
||||
torch::stable::Tensor const& a_strides,
|
||||
torch::stable::Tensor const& b_strides,
|
||||
torch::stable::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch);
|
||||
|
||||
#endif
|
||||
|
||||
#if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100
|
||||
void cutlass_moe_mm_sm100(torch::stable::Tensor& out_tensors,
|
||||
torch::stable::Tensor const& a_tensors,
|
||||
torch::stable::Tensor const& b_tensors,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
torch::stable::Tensor const& expert_offsets,
|
||||
torch::stable::Tensor const& problem_sizes,
|
||||
torch::stable::Tensor const& a_strides,
|
||||
torch::stable::Tensor const& b_strides,
|
||||
torch::stable::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch);
|
||||
#endif
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
|
||||
void cutlass_scaled_mm_sm120(torch::stable::Tensor& c,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
std::optional<torch::stable::Tensor> const& bias);
|
||||
#endif
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
|
||||
void cutlass_scaled_mm_sm100(torch::stable::Tensor& c,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
std::optional<torch::stable::Tensor> const& bias);
|
||||
#endif
|
||||
|
||||
#if (defined(ENABLE_CUTLASS_MOE_SM90) && ENABLE_CUTLASS_MOE_SM90) || \
|
||||
(defined(ENABLE_CUTLASS_MOE_SM100) && ENABLE_CUTLASS_MOE_SM100) || \
|
||||
(defined(ENABLE_CUTLASS_MOE_SM120) && ENABLE_CUTLASS_MOE_SM120)
|
||||
void get_cutlass_moe_mm_data_caller(
|
||||
const torch::stable::Tensor& topk_ids,
|
||||
torch::stable::Tensor& expert_offsets,
|
||||
torch::stable::Tensor& problem_sizes1,
|
||||
torch::stable::Tensor& problem_sizes2,
|
||||
torch::stable::Tensor& input_permutation,
|
||||
torch::stable::Tensor& output_permutation, const int64_t num_experts,
|
||||
const int64_t n, const int64_t k,
|
||||
const std::optional<torch::stable::Tensor>& blockscale_offsets,
|
||||
const bool is_gated);
|
||||
|
||||
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
|
||||
const torch::stable::Tensor& expert_first_token_offset,
|
||||
torch::stable::Tensor& problem_sizes1,
|
||||
torch::stable::Tensor& problem_sizes2, const int64_t n, const int64_t k,
|
||||
const bool swap_ab);
|
||||
|
||||
void get_cutlass_batched_moe_mm_data_caller(
|
||||
torch::stable::Tensor& expert_offsets,
|
||||
torch::stable::Tensor& problem_sizes1,
|
||||
torch::stable::Tensor& problem_sizes2,
|
||||
const torch::stable::Tensor& expert_num_tokens,
|
||||
const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
|
||||
const int64_t k);
|
||||
#endif
|
||||
|
||||
void cutlass_scaled_mm_azp_sm75(
|
||||
torch::stable::Tensor& c, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
|
||||
std::optional<torch::stable::Tensor> const& azp,
|
||||
std::optional<torch::stable::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_azp_sm80(
|
||||
torch::stable::Tensor& c, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
|
||||
std::optional<torch::stable::Tensor> const& azp,
|
||||
std::optional<torch::stable::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_azp_sm89(
|
||||
torch::stable::Tensor& c, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
|
||||
std::optional<torch::stable::Tensor> const& azp,
|
||||
std::optional<torch::stable::Tensor> const& bias);
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
|
||||
void cutlass_scaled_mm_azp_sm90(
|
||||
torch::stable::Tensor& c, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
|
||||
std::optional<torch::stable::Tensor> const& azp,
|
||||
std::optional<torch::stable::Tensor> const& bias);
|
||||
#endif
|
||||
|
||||
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
|
||||
// CUTLASS FP8 kernels need at least
|
||||
// CUDA 12.0 on SM90 systems (Hopper)
|
||||
// CUDA 12.4 on SM89 systems (Lovelace)
|
||||
|
||||
#if defined CUDA_VERSION
|
||||
if (cuda_device_capability >= 90) {
|
||||
return CUDA_VERSION >= 12000;
|
||||
} else if (cuda_device_capability >= 89) {
|
||||
return CUDA_VERSION >= 12040;
|
||||
}
|
||||
#endif
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
|
||||
// CUTLASS block-quantized FP8 kernels need at least CUDA 12.0
|
||||
// and at least SM90 (Hopper)
|
||||
|
||||
#if defined CUDA_VERSION
|
||||
if (cuda_device_capability >= 100) {
|
||||
return CUDA_VERSION >= 12080;
|
||||
} else if (cuda_device_capability >= 90) {
|
||||
return CUDA_VERSION >= 12000;
|
||||
}
|
||||
#endif
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool cutlass_group_gemm_supported(int64_t cuda_device_capability) {
|
||||
// CUTLASS grouped FP8 kernels need at least CUDA 12.3 and SM90 (Hopper)
|
||||
// or CUDA 12.8 and SM100 (Blackwell)
|
||||
|
||||
#if defined CUDA_VERSION
|
||||
if (cuda_device_capability >= 100) {
|
||||
return CUDA_VERSION >= 12080;
|
||||
}
|
||||
if (cuda_device_capability >= 90) {
|
||||
return CUDA_VERSION >= 12030;
|
||||
}
|
||||
#endif
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm(torch::stable::Tensor& c, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
std::optional<torch::stable::Tensor> const& bias) {
|
||||
// Checks for conformality
|
||||
STD_TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
|
||||
STD_TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
|
||||
b.size(1) == c.size(1));
|
||||
|
||||
// Check for strides and alignment
|
||||
STD_TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major
|
||||
STD_TORCH_CHECK(b.stride(0) == 1); // Column-major
|
||||
STD_TORCH_CHECK(c.stride(0) % 16 == 0 &&
|
||||
b.stride(1) % 16 == 0); // 16 Byte Alignment
|
||||
|
||||
if (bias) {
|
||||
STD_TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
|
||||
bias->dim() == 1);
|
||||
}
|
||||
|
||||
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||
a.get_device_index());
|
||||
int32_t version_num = get_sm_version_num();
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
|
||||
if (version_num >= 120) {
|
||||
cutlass_scaled_mm_sm120(c, a, b, a_scales, b_scales, bias);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
|
||||
if (version_num >= 100 && version_num < 120) {
|
||||
cutlass_scaled_mm_sm100(c, a, b, a_scales, b_scales, bias);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
// Guard against compilation issues for sm90 kernels
|
||||
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
|
||||
if (version_num >= 90 && version_num < 100) {
|
||||
// Hopper
|
||||
cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
|
||||
if (version_num == 89) {
|
||||
// Ada Lovelace
|
||||
cutlass_scaled_mm_sm89(c, a, b, a_scales, b_scales, bias);
|
||||
return;
|
||||
}
|
||||
|
||||
if (version_num >= 80) {
|
||||
// Ampere
|
||||
cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias);
|
||||
return;
|
||||
}
|
||||
|
||||
if (version_num >= 75) {
|
||||
// Turing
|
||||
cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
STD_TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false,
|
||||
"No compiled cutlass_scaled_mm for a compute capability less than "
|
||||
"CUDA device capability: ",
|
||||
version_num);
|
||||
}
|
||||
|
||||
void cutlass_moe_mm(torch::stable::Tensor& out_tensors,
|
||||
torch::stable::Tensor const& a_tensors,
|
||||
torch::stable::Tensor const& b_tensors,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
torch::stable::Tensor const& expert_offsets,
|
||||
torch::stable::Tensor const& problem_sizes,
|
||||
torch::stable::Tensor const& a_strides,
|
||||
torch::stable::Tensor const& b_strides,
|
||||
torch::stable::Tensor const& c_strides, bool per_act_token,
|
||||
bool per_out_ch) {
|
||||
int32_t version_num = get_sm_version_num();
|
||||
#if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100
|
||||
if (version_num >= 100 && version_num < 110) {
|
||||
cutlass_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
|
||||
expert_offsets, problem_sizes, a_strides, b_strides,
|
||||
c_strides, per_act_token, per_out_ch);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
|
||||
if (version_num >= 90 && version_num < 100) {
|
||||
cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
|
||||
expert_offsets, problem_sizes, a_strides, b_strides,
|
||||
c_strides, per_act_token, per_out_ch);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
STD_TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false,
|
||||
"No compiled cutlass_scaled_mm for CUDA device capability: ", version_num,
|
||||
". Required capability: 90 or 100");
|
||||
}
|
||||
|
||||
void get_cutlass_moe_mm_data(
|
||||
const torch::stable::Tensor& topk_ids,
|
||||
torch::stable::Tensor& expert_offsets,
|
||||
torch::stable::Tensor& problem_sizes1,
|
||||
torch::stable::Tensor& problem_sizes2,
|
||||
torch::stable::Tensor& input_permutation,
|
||||
torch::stable::Tensor& output_permutation, const int64_t num_experts,
|
||||
const int64_t n, const int64_t k,
|
||||
const std::optional<torch::stable::Tensor>& blockscale_offsets,
|
||||
const bool is_gated) {
|
||||
// This function currently gets compiled only if we have a valid cutlass moe
|
||||
// mm to run it for.
|
||||
int32_t version_num = get_sm_version_num();
|
||||
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
|
||||
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
|
||||
(defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
|
||||
get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1,
|
||||
problem_sizes2, input_permutation,
|
||||
output_permutation, num_experts, n, k,
|
||||
blockscale_offsets, is_gated);
|
||||
return;
|
||||
#endif
|
||||
STD_TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false,
|
||||
"No compiled get_cutlass_moe_mm_data: no cutlass_scaled_mm kernel for "
|
||||
"CUDA device capability: ",
|
||||
version_num, ". Required capability: 90, 100, or 120");
|
||||
}
|
||||
|
||||
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
|
||||
const torch::stable::Tensor& expert_first_token_offset,
|
||||
torch::stable::Tensor& problem_sizes1,
|
||||
torch::stable::Tensor& problem_sizes2, const int64_t n, const int64_t k,
|
||||
const bool swap_ab) {
|
||||
int32_t version_num = get_sm_version_num();
|
||||
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
|
||||
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
|
||||
(defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
|
||||
get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
|
||||
expert_first_token_offset, problem_sizes1, problem_sizes2, n, k, swap_ab);
|
||||
return;
|
||||
#endif
|
||||
STD_TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false,
|
||||
"No compiled get_cutlass_moe_mm_problem_sizes_from_expert_offsets: "
|
||||
"no cutlass_scaled_mm kernel for CUDA device capability: ",
|
||||
version_num, ". Required capability: 90, 100, or 120");
|
||||
}
|
||||
|
||||
void get_cutlass_batched_moe_mm_data(
|
||||
torch::stable::Tensor& expert_offsets,
|
||||
torch::stable::Tensor& problem_sizes1,
|
||||
torch::stable::Tensor& problem_sizes2,
|
||||
const torch::stable::Tensor& expert_num_tokens,
|
||||
const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
|
||||
const int64_t k) {
|
||||
// This function currently gets compiled only if we have a valid cutlass moe
|
||||
// mm to run it for.
|
||||
int32_t version_num = get_sm_version_num();
|
||||
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
|
||||
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
|
||||
(defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
|
||||
get_cutlass_batched_moe_mm_data_caller(expert_offsets, problem_sizes1,
|
||||
problem_sizes2, expert_num_tokens,
|
||||
num_local_experts, padded_m, n, k);
|
||||
return;
|
||||
#endif
|
||||
STD_TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false,
|
||||
"No compiled get_cutlass_batched_moe_mm_data: no "
|
||||
"cutlass_scaled_mm kernel "
|
||||
"for CUDA device capability: ",
|
||||
version_num, ". Required capability: 90, 100, or 120");
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm_azp(torch::stable::Tensor& c,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
torch::stable::Tensor const& azp_adj,
|
||||
std::optional<torch::stable::Tensor> const& azp,
|
||||
std::optional<torch::stable::Tensor> const& bias) {
|
||||
// Checks for conformality
|
||||
STD_TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
|
||||
STD_TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
|
||||
b.size(1) == c.size(1));
|
||||
STD_TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
|
||||
STD_TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));
|
||||
|
||||
// Check for strides and alignment
|
||||
STD_TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major
|
||||
STD_TORCH_CHECK(b.stride(0) == 1); // Column-major
|
||||
STD_TORCH_CHECK(c.stride(0) % 16 == 0 &&
|
||||
b.stride(1) % 16 == 0); // 16 Byte Alignment
|
||||
STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
|
||||
// bias, azp, azp_adj are all 1d
|
||||
// bias and azp_adj have n elements, azp has m elements
|
||||
if (bias) {
|
||||
STD_TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous());
|
||||
}
|
||||
if (azp) {
|
||||
STD_TORCH_CHECK(azp->numel() == a.size(0) && azp->is_contiguous());
|
||||
}
|
||||
STD_TORCH_CHECK(azp_adj.numel() == b.size(1) && azp_adj.is_contiguous());
|
||||
|
||||
// azp & bias types
|
||||
STD_TORCH_CHECK(azp_adj.scalar_type() == torch::headeronly::ScalarType::Int);
|
||||
STD_TORCH_CHECK(!azp ||
|
||||
azp->scalar_type() == torch::headeronly::ScalarType::Int);
|
||||
STD_TORCH_CHECK(!bias || bias->scalar_type() == c.scalar_type(),
|
||||
"currently bias dtype must match output dtype ",
|
||||
c.scalar_type());
|
||||
|
||||
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||
a.get_device_index());
|
||||
|
||||
int32_t version_num = get_sm_version_num();
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
|
||||
if (version_num >= 90) {
|
||||
cutlass_scaled_mm_azp_sm90(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
|
||||
if (version_num == 89) {
|
||||
// Ada Lovelace
|
||||
cutlass_scaled_mm_azp_sm89(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
|
||||
return;
|
||||
}
|
||||
|
||||
if (version_num >= 80) {
|
||||
// Ampere
|
||||
cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
|
||||
return;
|
||||
}
|
||||
|
||||
// Turing
|
||||
STD_TORCH_CHECK(version_num >= 75);
|
||||
cutlass_scaled_mm_azp_sm75(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
|
||||
return;
|
||||
#endif
|
||||
|
||||
STD_TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false,
|
||||
"No compiled cutlass_scaled_mm_azp for a compute capability less than "
|
||||
"CUDA device capability: ",
|
||||
version_num);
|
||||
}
|
||||
@@ -31,78 +31,6 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
|
||||
"per_token_group_quant_int8(Tensor input, Tensor! output_q, Tensor! "
|
||||
"output_s, int group_size, float eps, float int8_min, float int8_max) -> "
|
||||
"()");
|
||||
|
||||
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
|
||||
// quantization, as well as bias
|
||||
ops.def(
|
||||
"cutlass_scaled_mm(Tensor! out, Tensor a,"
|
||||
" Tensor b, Tensor a_scales,"
|
||||
" Tensor b_scales, Tensor? bias) -> ()");
|
||||
|
||||
// CUTLASS w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
|
||||
// quantization.
|
||||
ops.def(
|
||||
"cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
|
||||
" Tensor b, Tensor a_scales,"
|
||||
" Tensor b_scales, Tensor azp_adj,"
|
||||
" Tensor? azp, Tensor? bias) -> ()");
|
||||
|
||||
// Check if cutlass scaled_mm is supported for CUDA devices of the given
|
||||
// capability
|
||||
ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool");
|
||||
|
||||
// Check if cutlass grouped gemm is supported for CUDA devices of the given
|
||||
// capability
|
||||
ops.def("cutlass_group_gemm_supported(int cuda_device_capability) -> bool");
|
||||
|
||||
// CUTLASS w8a8 grouped GEMM
|
||||
ops.def(
|
||||
"cutlass_moe_mm(Tensor! out_tensors, Tensor a_tensors, Tensor b_tensors, "
|
||||
" Tensor a_scales, Tensor b_scales, Tensor expert_offsets, "
|
||||
" Tensor problem_sizes, Tensor a_strides, "
|
||||
" Tensor b_strides, Tensor c_strides, bool per_act_token, "
|
||||
" bool per_out_ch) -> ()");
|
||||
|
||||
// A function that computes data required to run fused MoE with w8a8 grouped
|
||||
// GEMM. It takes topk_ids as an input, and computes expert_offsets
|
||||
// (token start indices of each expert). In addition to this, it computes
|
||||
// problem sizes for each expert's multiplication used by the two mms called
|
||||
// from fused MoE operation, and arrays with permutations required to shuffle
|
||||
// and de-shuffle the input/output of the fused operation.
|
||||
ops.def(
|
||||
"get_cutlass_moe_mm_data(Tensor topk_ids, Tensor! expert_offsets, "
|
||||
" Tensor! problem_sizes1, Tensor! problem_sizes2, "
|
||||
" Tensor! input_permutation, "
|
||||
" Tensor! output_permutation, int num_experts, "
|
||||
" int n, int k, Tensor? blockscale_offsets, "
|
||||
" bool is_gated) -> ()");
|
||||
|
||||
// compute per-expert problem sizes from expert_first_token_offset
|
||||
// produced by vLLM's moe_permute kernel
|
||||
ops.def(
|
||||
"get_cutlass_moe_mm_problem_sizes_from_expert_offsets("
|
||||
" Tensor expert_first_token_offset, "
|
||||
" Tensor! problem_sizes1, "
|
||||
" Tensor! problem_sizes2, "
|
||||
" int n, int k, bool swap_ab) -> ()");
|
||||
|
||||
// A function that computes data required to run fused MoE with w8a8 grouped
|
||||
// GEMM in batched expert format. It takes expert_num_tokens
|
||||
// as an input, and computes expert_offsets (token start indices of each
|
||||
// expert). In addition to this, it computes problem sizes for each expert's
|
||||
// multiplication used by the two mms called from fused MoE operation.
|
||||
ops.def(
|
||||
"get_cutlass_batched_moe_mm_data(Tensor! expert_offsets, "
|
||||
" Tensor! problem_sizes1, "
|
||||
" Tensor! problem_sizes2, "
|
||||
" Tensor expert_num_tokens, "
|
||||
" int num_local_experts, int padded_m, "
|
||||
" int n, int k) -> ()");
|
||||
|
||||
// Check if cutlass scaled_mm supports block quantization (used by DeepSeekV3)
|
||||
ops.def(
|
||||
"cutlass_scaled_mm_supports_block_fp8(int cuda_device_capability) -> "
|
||||
"bool");
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -118,31 +46,6 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
|
||||
TORCH_BOX(&per_token_group_quant_8bit_packed));
|
||||
ops.impl("per_token_group_quant_int8",
|
||||
TORCH_BOX(&per_token_group_quant_int8));
|
||||
|
||||
// CUTLASS scaled_mm ops
|
||||
ops.impl("cutlass_scaled_mm", TORCH_BOX(&cutlass_scaled_mm));
|
||||
ops.impl("cutlass_scaled_mm_azp", TORCH_BOX(&cutlass_scaled_mm_azp));
|
||||
ops.impl("cutlass_moe_mm", TORCH_BOX(&cutlass_moe_mm));
|
||||
ops.impl("get_cutlass_moe_mm_data", TORCH_BOX(&get_cutlass_moe_mm_data));
|
||||
ops.impl("get_cutlass_moe_mm_problem_sizes_from_expert_offsets",
|
||||
TORCH_BOX(&get_cutlass_moe_mm_problem_sizes_from_expert_offsets));
|
||||
ops.impl("get_cutlass_batched_moe_mm_data",
|
||||
TORCH_BOX(&get_cutlass_batched_moe_mm_data));
|
||||
#endif
|
||||
}
|
||||
|
||||
// These capability-check functions take only primitive args (no tensors), so
|
||||
// there is no device to dispatch on. CompositeExplicitAutograd makes them
|
||||
// available for all backends. This is the stable ABI equivalent of calling
|
||||
// ops.impl("op_name", &func) without a dispatch key in the non-stable API.
|
||||
STABLE_TORCH_LIBRARY_IMPL(_C, CompositeExplicitAutograd, ops) {
|
||||
#ifndef USE_ROCM
|
||||
ops.impl("cutlass_scaled_mm_supports_fp8",
|
||||
TORCH_BOX(&cutlass_scaled_mm_supports_fp8));
|
||||
ops.impl("cutlass_group_gemm_supported",
|
||||
TORCH_BOX(&cutlass_group_gemm_supported));
|
||||
ops.impl("cutlass_scaled_mm_supports_block_fp8",
|
||||
TORCH_BOX(&cutlass_scaled_mm_supports_block_fp8));
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
@@ -1,16 +1,10 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
|
||||
#include <torch/csrc/stable/accelerator.h>
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include <torch/headeronly/util/shim_utils.h>
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
// Stable ABI equivalent of TORCH_CHECK_NOT_IMPLEMENTED.
|
||||
#define STD_TORCH_CHECK_NOT_IMPLEMENTED(cond, ...) \
|
||||
STD_TORCH_CHECK(cond, "NotImplementedError: ", __VA_ARGS__)
|
||||
|
||||
// Utility to get the current CUDA stream for a given device using stable APIs.
|
||||
// Returns a cudaStream_t for use in kernel launches.
|
||||
inline cudaStream_t get_current_cuda_stream(int32_t device_index = -1) {
|
||||
|
||||
@@ -21,7 +21,7 @@ struct SSMParamsBase {
|
||||
int dim_ngroups_ratio;
|
||||
bool is_variable_B;
|
||||
bool is_variable_C;
|
||||
int64_t null_block_id;
|
||||
int64_t pad_slot_id;
|
||||
|
||||
bool delta_softplus;
|
||||
bool cache_enabled;
|
||||
|
||||
@@ -118,17 +118,9 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
|
||||
|
||||
const int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr
|
||||
: reinterpret_cast<int *>(params.cache_indices_ptr);
|
||||
int cache_index;
|
||||
if (cache_indices == nullptr) {
|
||||
cache_index = batch_id;
|
||||
} else if (params.cache_enabled) {
|
||||
const int* initial_state_idx = reinterpret_cast<const int*>(params.initial_state_idx_ptr);
|
||||
cache_index = cache_indices[batch_id * params.cache_indices_stride + initial_state_idx[batch_id]];
|
||||
} else {
|
||||
cache_index = cache_indices[batch_id];
|
||||
}
|
||||
// Skip batch entries whose cache index maps to the null block (padding).
|
||||
if (cache_indices != nullptr && cache_index == params.null_block_id){
|
||||
const int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];
|
||||
// cache_index == params.pad_slot_id is defined as padding, so we exit early
|
||||
if (cache_index == params.pad_slot_id){
|
||||
return;
|
||||
}
|
||||
input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + sequence_start_index * params.u_batch_stride
|
||||
@@ -535,7 +527,7 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms,
|
||||
const std::optional<at::Tensor>& cache_indices,
|
||||
const std::optional<at::Tensor>& has_initial_state,
|
||||
bool varlen,
|
||||
int64_t null_block_id,
|
||||
int64_t pad_slot_id,
|
||||
int64_t block_size,
|
||||
const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
|
||||
const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
|
||||
@@ -552,7 +544,7 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms,
|
||||
params.dstate = dstate;
|
||||
params.n_groups = n_groups;
|
||||
params.dim_ngroups_ratio = dim / n_groups;
|
||||
params.null_block_id = null_block_id;
|
||||
params.pad_slot_id = pad_slot_id;
|
||||
|
||||
params.delta_softplus = delta_softplus;
|
||||
|
||||
@@ -666,7 +658,7 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
|
||||
const torch::Tensor &ssm_states,
|
||||
// used to identify padding entries if cache_indices provided
|
||||
// in case of padding, the kernel will return early
|
||||
int64_t null_block_id,
|
||||
int64_t pad_slot_id,
|
||||
int64_t block_size,
|
||||
const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
|
||||
const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
|
||||
@@ -813,7 +805,7 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
|
||||
cache_indices,
|
||||
has_initial_state,
|
||||
varlen,
|
||||
null_block_id,
|
||||
pad_slot_id,
|
||||
block_size,
|
||||
block_idx_first_scheduled_token,
|
||||
block_idx_last_scheduled_token,
|
||||
|
||||
47
csrc/ops.h
47
csrc/ops.h
@@ -228,18 +228,63 @@ int64_t ggml_moe_get_block_size(int64_t type);
|
||||
#ifndef USE_ROCM
|
||||
|
||||
bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability);
|
||||
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
|
||||
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability);
|
||||
bool cutlass_group_gemm_supported(int64_t cuda_device_capability);
|
||||
|
||||
void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A,
|
||||
torch::Tensor const& B, torch::Tensor const& A_sf,
|
||||
torch::Tensor const& B_sf,
|
||||
torch::Tensor const& alpha);
|
||||
|
||||
void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
void cutlass_moe_mm(
|
||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch);
|
||||
|
||||
void cutlass_fp4_group_mm(
|
||||
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
|
||||
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
|
||||
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
|
||||
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets);
|
||||
|
||||
void get_cutlass_moe_mm_data(
|
||||
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
|
||||
const int64_t num_experts, const int64_t n, const int64_t k,
|
||||
const std::optional<torch::Tensor>& blockscale_offsets,
|
||||
const bool is_gated);
|
||||
|
||||
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
|
||||
const torch::Tensor& expert_first_token_offset,
|
||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||
const int64_t n, const int64_t k, const bool swap_ab);
|
||||
|
||||
void get_cutlass_batched_moe_mm_data(torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2,
|
||||
const torch::Tensor& expert_num_tokens,
|
||||
const int64_t num_local_experts,
|
||||
const int64_t padded_m, const int64_t n,
|
||||
const int64_t k);
|
||||
|
||||
void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
std::optional<torch::Tensor> const& azp,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> scaled_fp4_quant_func(
|
||||
torch::Tensor const& input, torch::Tensor const& input_scale,
|
||||
bool is_sf_swizzled_layout);
|
||||
@@ -298,7 +343,7 @@ void selective_scan_fwd(
|
||||
const std::optional<torch::Tensor>& query_start_loc,
|
||||
const std::optional<torch::Tensor>& cache_indices,
|
||||
const std::optional<torch::Tensor>& has_initial_state,
|
||||
const torch::Tensor& ssm_states, int64_t null_block_id, int64_t block_size,
|
||||
const torch::Tensor& ssm_states, int64_t pad_slot_id, int64_t block_size,
|
||||
const std::optional<torch::Tensor>& block_idx_first_scheduled_token,
|
||||
const std::optional<torch::Tensor>& block_idx_last_scheduled_token,
|
||||
const std::optional<torch::Tensor>& initial_state_idx,
|
||||
|
||||
@@ -16,7 +16,6 @@
|
||||
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
#include "nvfp4_utils.cuh"
|
||||
|
||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||
@@ -54,27 +53,12 @@ void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
|
||||
torch::Tensor const& output_scale_offset_by_experts);
|
||||
#endif
|
||||
|
||||
static bool nvfp4_quant_sm_supported() {
|
||||
const int32_t sm = get_sm_version_num();
|
||||
#if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100
|
||||
if (sm >= 100 && sm < 120) return true;
|
||||
#endif
|
||||
#if defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120
|
||||
if (sm >= 120 && sm < 130) return true;
|
||||
#endif
|
||||
return false;
|
||||
}
|
||||
|
||||
void scaled_fp4_quant_out(torch::Tensor const& input,
|
||||
torch::Tensor const& input_sf,
|
||||
bool is_sf_swizzled_layout, torch::Tensor& output,
|
||||
torch::Tensor& output_sf) {
|
||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||
TORCH_CHECK(nvfp4_quant_sm_supported(),
|
||||
"No compiled nvfp4 quantization kernel for SM ",
|
||||
get_sm_version_num(),
|
||||
". Recompile with the appropriate CUDA arch.");
|
||||
return scaled_fp4_quant_sm1xxa(output, input, output_sf, input_sf,
|
||||
is_sf_swizzled_layout);
|
||||
#endif
|
||||
@@ -116,10 +100,6 @@ void scaled_fp4_experts_quant(
|
||||
torch::Tensor const& output_scale_offset_by_experts) {
|
||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||
TORCH_CHECK(nvfp4_quant_sm_supported(),
|
||||
"No compiled nvfp4 experts quantization kernel for SM ",
|
||||
get_sm_version_num(),
|
||||
". Recompile with the appropriate CUDA arch.");
|
||||
return scaled_fp4_experts_quant_sm1xxa(
|
||||
output, output_scale, input, input_global_scale, input_offset_by_experts,
|
||||
output_scale_offset_by_experts);
|
||||
@@ -132,10 +112,6 @@ void silu_and_mul_nvfp4_quant(torch::Tensor& output, torch::Tensor& output_sf,
|
||||
torch::Tensor& input, torch::Tensor& input_sf) {
|
||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||
TORCH_CHECK(nvfp4_quant_sm_supported(),
|
||||
"No compiled silu_and_mul nvfp4 quantization kernel for SM ",
|
||||
get_sm_version_num(),
|
||||
". Recompile with the appropriate CUDA arch.");
|
||||
return silu_and_mul_nvfp4_quant_sm1xxa(output, output_sf, input, input_sf);
|
||||
#endif
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
@@ -149,11 +125,6 @@ void silu_and_mul_scaled_fp4_experts_quant(
|
||||
torch::Tensor const& output_scale_offset_by_experts) {
|
||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||
TORCH_CHECK(nvfp4_quant_sm_supported(),
|
||||
"No compiled silu_and_mul nvfp4 experts quantization kernel "
|
||||
"for SM ",
|
||||
get_sm_version_num(),
|
||||
". Recompile with the appropriate CUDA arch.");
|
||||
return silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
|
||||
output, output_scale, input, input_global_scale, input_offset_by_experts,
|
||||
output_scale_offset_by_experts);
|
||||
|
||||
@@ -63,17 +63,5 @@ void cutlass_scaled_fp4_mm(torch::Tensor& D, const torch::Tensor& A,
|
||||
bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability) {
|
||||
int runtimeVersion;
|
||||
cudaRuntimeGetVersion(&runtimeVersion);
|
||||
if (runtimeVersion < 12080) return false;
|
||||
// Only report support when the SM-specific kernel was actually compiled in,
|
||||
// so the Python-side backend selector does not choose CUTLASS and then hit
|
||||
// TORCH_CHECK_NOT_IMPLEMENTED (or worse, fall through to Marlin).
|
||||
#if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100
|
||||
if (cuda_device_capability >= 100 && cuda_device_capability < 120)
|
||||
return true;
|
||||
#endif
|
||||
#if defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120
|
||||
if (cuda_device_capability >= 120 && cuda_device_capability < 130)
|
||||
return true;
|
||||
#endif
|
||||
return false;
|
||||
return cuda_device_capability >= 100 && runtimeVersion >= 12080;
|
||||
}
|
||||
|
||||
@@ -154,7 +154,6 @@ struct MacheteCollectiveMma {
|
||||
struct DispatchPolicy {
|
||||
constexpr static int Stages = PipelineStages;
|
||||
using ClusterShape = ClusterShape_MNK;
|
||||
using ArchTag = arch::Sm90;
|
||||
using Schedule = KernelScheduleType;
|
||||
};
|
||||
|
||||
|
||||
@@ -2,10 +2,9 @@
|
||||
|
||||
// clang-format will break include orders
|
||||
// clang-format off
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include <torch/csrc/stable/ops.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
@@ -26,14 +25,14 @@
|
||||
namespace vllm::c3x {
|
||||
|
||||
static inline cute::Shape<int, int, int, int> get_problem_shape(
|
||||
torch::stable::Tensor const& a, torch::stable::Tensor const& b) {
|
||||
torch::Tensor const& a, torch::Tensor const& b) {
|
||||
int32_t m = a.size(0), n = b.size(1), k = a.size(1);
|
||||
return {m, n, k, 1};
|
||||
}
|
||||
|
||||
template <typename GemmKernel>
|
||||
void cutlass_gemm_caller(
|
||||
torch::stable::Device device, cute::Shape<int, int, int, int> prob_shape,
|
||||
torch::Device device, cute::Shape<int, int, int, int> prob_shape,
|
||||
typename GemmKernel::MainloopArguments mainloop_args,
|
||||
typename GemmKernel::EpilogueArguments epilogue_args,
|
||||
typename GemmKernel::TileSchedulerArguments scheduler = {}) {
|
||||
@@ -51,20 +50,19 @@ void cutlass_gemm_caller(
|
||||
CUTLASS_CHECK(gemm_op.can_implement(args));
|
||||
|
||||
size_t workspace_size = gemm_op.get_workspace_size(args);
|
||||
auto workspace =
|
||||
torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
|
||||
std::nullopt, device);
|
||||
auto const workspace_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(device);
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
|
||||
auto stream = get_current_cuda_stream(device.index());
|
||||
auto stream = at::cuda::getCurrentCUDAStream(device.index());
|
||||
|
||||
cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
template <typename Gemm, typename... EpilogueArgs>
|
||||
void cutlass_gemm_caller(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_params) {
|
||||
using ElementAB = typename Gemm::ElementAB;
|
||||
using ElementC = typename Gemm::ElementC;
|
||||
@@ -4,12 +4,13 @@
|
||||
|
||||
namespace vllm {
|
||||
|
||||
void cutlass_scaled_mm_azp_sm90_int8(
|
||||
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
|
||||
std::optional<torch::stable::Tensor> const& azp,
|
||||
std::optional<torch::stable::Tensor> const& bias) {
|
||||
void cutlass_scaled_mm_azp_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
std::optional<torch::Tensor> const& azp,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
if (azp) {
|
||||
return cutlass_scaled_mm_sm90_int8_epilogue<
|
||||
c3x::ScaledEpilogueBiasAzpToken>(out, a, b, a_scales, b_scales, azp_adj,
|
||||
@@ -0,0 +1,23 @@
|
||||
#include "scaled_mm_kernels.hpp"
|
||||
#include "scaled_mm_blockwise_sm100_fp8_dispatch.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::bfloat16_t>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::half_t>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@@ -1,7 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/headeronly/util/shim_utils.h>
|
||||
|
||||
#include "cuda_utils.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
@@ -132,10 +130,10 @@ struct cutlass_3x_gemm_fp8_blockwise {
|
||||
};
|
||||
|
||||
template <typename Gemm>
|
||||
void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales) {
|
||||
void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
static constexpr bool swap_ab = Gemm::swap_ab;
|
||||
using GemmKernel = typename Gemm::GemmKernel;
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
@@ -202,11 +200,11 @@ void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Te
|
||||
}
|
||||
|
||||
template <typename OutType>
|
||||
void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales) {
|
||||
void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
int32_t m = a.size(0), n = b.size(1), k = a.size(1), sms;
|
||||
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device());
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
#include "scaled_mm_kernels.hpp"
|
||||
#include "scaled_mm_blockwise_sm120_fp8_dispatch.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
void cutlass_scaled_mm_blockwise_sm120_fp8(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
cutlass_gemm_blockwise_sm120_fp8_dispatch<cutlass::bfloat16_t>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
cutlass_gemm_blockwise_sm120_fp8_dispatch<cutlass::half_t>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@@ -1,7 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/headeronly/util/shim_utils.h>
|
||||
|
||||
#include "cuda_utils.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
@@ -140,10 +138,10 @@ struct sm120_blockwise_fp8_config_M64 {
|
||||
};
|
||||
|
||||
template <typename Gemm>
|
||||
void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales) {
|
||||
void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
using GemmKernel = typename Gemm::GemmKernel;
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
@@ -198,11 +196,11 @@ void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Te
|
||||
}
|
||||
|
||||
template <typename OutType>
|
||||
void cutlass_gemm_blockwise_sm120_fp8_dispatch(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales) {
|
||||
void cutlass_gemm_blockwise_sm120_fp8_dispatch(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
int M = a.size(0);
|
||||
if (M <= 256) {
|
||||
using Gemm = typename sm120_blockwise_fp8_config_M64<OutType>::Gemm;
|
||||
@@ -0,0 +1,24 @@
|
||||
|
||||
#include "scaled_mm_kernels.hpp"
|
||||
#include "scaled_mm_blockwise_sm90_fp8_dispatch.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
void cutlass_scaled_mm_blockwise_sm90_fp8(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::bfloat16_t>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::half_t>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@@ -1,7 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/headeronly/util/shim_utils.h>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
@@ -103,10 +101,10 @@ struct cutlass_3x_gemm_fp8_blockwise {
|
||||
};
|
||||
|
||||
template <typename Gemm>
|
||||
void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales) {
|
||||
void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
using GemmKernel = typename Gemm::GemmKernel;
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
@@ -122,7 +120,7 @@ void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Te
|
||||
|
||||
int32_t m = a.size(0), n = b.size(1), k = a.size(1);
|
||||
|
||||
STD_TORCH_CHECK(m % 4 == 0, "m must be divisible by 4");
|
||||
TORCH_CHECK(m % 4 == 0, "m must be divisible by 4");
|
||||
|
||||
StrideA a_stride;
|
||||
StrideB b_stride;
|
||||
@@ -163,11 +161,11 @@ void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Te
|
||||
}
|
||||
|
||||
template <typename OutType>
|
||||
void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales) {
|
||||
void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
// TODO: better heuristics
|
||||
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
|
||||
OutType, 1, 128, 128, Shape<_128, _128, _128>,
|
||||
@@ -1,57 +1,52 @@
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include <torch/headeronly/core/ScalarType.h>
|
||||
#include <torch/all.h>
|
||||
#include "cuda_utils.h"
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
|
||||
template <typename Fp8Func, typename Int8Func, typename BlockwiseFunc>
|
||||
void dispatch_scaled_mm(torch::stable::Tensor& c,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
std::optional<torch::stable::Tensor> const& bias,
|
||||
void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias,
|
||||
Fp8Func fp8_func, Int8Func int8_func,
|
||||
BlockwiseFunc blockwise_func) {
|
||||
STD_TORCH_CHECK(a_scales.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float);
|
||||
STD_TORCH_CHECK(b_scales.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float);
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
|
||||
int M = a.size(0), N = b.size(1), K = a.size(1);
|
||||
|
||||
if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
|
||||
(b_scales.numel() == 1 || b_scales.numel() == b.size(1))) {
|
||||
// Standard per-tensor/per-token/per-channel scaling
|
||||
STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
if (a.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn) {
|
||||
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
if (a.dtype() == torch::kFloat8_e4m3fn) {
|
||||
fp8_func(c, a, b, a_scales, b_scales, bias);
|
||||
} else {
|
||||
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
if constexpr (!std::is_same_v<Int8Func, std::nullptr_t>) {
|
||||
int8_func(c, a, b, a_scales, b_scales, bias);
|
||||
} else {
|
||||
int32_t version_num = get_sm_version_num();
|
||||
STD_TORCH_CHECK(
|
||||
TORCH_CHECK(
|
||||
false, "Int8 not supported on SM", version_num,
|
||||
". Use FP8 quantization instead, or run on older arch (SM < 100).");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
STD_TORCH_CHECK(a_scales.dim() == 2, "a scale must be 2d tensor.");
|
||||
STD_TORCH_CHECK(b_scales.dim() == 2, "b scale must be 2d tensor.");
|
||||
TORCH_CHECK(a_scales.dim() == 2, "a scale must be 2d tensor.");
|
||||
TORCH_CHECK(b_scales.dim() == 2, "b scale must be 2d tensor.");
|
||||
int32_t version_num = get_sm_version_num();
|
||||
if (version_num >= 90) {
|
||||
STD_TORCH_CHECK(
|
||||
TORCH_CHECK(
|
||||
a.size(0) == a_scales.size(0) &&
|
||||
cuda_utils::ceil_div(a.size(1), int64_t(128)) == a_scales.size(1),
|
||||
"a_scale_group_shape must be [1, 128].");
|
||||
STD_TORCH_CHECK(
|
||||
TORCH_CHECK(
|
||||
cuda_utils::ceil_div(b.size(0), int64_t(128)) == b_scales.size(0) &&
|
||||
cuda_utils::ceil_div(b.size(1), int64_t(128)) == b_scales.size(1),
|
||||
"b_scale_group_shape must be [128, 128].");
|
||||
}
|
||||
|
||||
STD_TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
|
||||
TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
|
||||
blockwise_func(c, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
56
csrc/quantization/w8a8/cutlass/c3x/scaled_mm_kernels.hpp
Normal file
56
csrc/quantization/w8a8/cutlass/c3x/scaled_mm_kernels.hpp
Normal file
@@ -0,0 +1,56 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/all.h>
|
||||
|
||||
namespace vllm {
|
||||
|
||||
void cutlass_scaled_mm_sm90_fp8(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_azp_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
std::optional<torch::Tensor> const& azp,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_blockwise_sm90_fp8(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales);
|
||||
|
||||
void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_sm120_fp8(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales);
|
||||
|
||||
void cutlass_scaled_mm_blockwise_sm120_fp8(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales);
|
||||
} // namespace vllm
|
||||
23
csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu
Normal file
23
csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu
Normal file
@@ -0,0 +1,23 @@
|
||||
#include "scaled_mm_kernels.hpp"
|
||||
#include "scaled_mm_sm100_fp8_dispatch.cuh"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
if (bias) {
|
||||
TORCH_CHECK(bias->dtype() == out.dtype(),
|
||||
"currently bias dtype must match output dtype ", out.dtype());
|
||||
return cutlass_scaled_mm_sm100_fp8_epilogue<true>(out, a, b, a_scales,
|
||||
b_scales, *bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm100_fp8_epilogue<false>(out, a, b, a_scales,
|
||||
b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@@ -1,7 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/headeronly/util/shim_utils.h>
|
||||
|
||||
#include "scaled_mm.cuh"
|
||||
#include "cutlass_gemm_caller.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
@@ -194,9 +192,8 @@ struct sm100_fp8_config_M16_swap_ab {
|
||||
};
|
||||
|
||||
template <typename Gemm, typename... EpilogueArgs>
|
||||
void cutlass_gemm_caller_sm100_fp8(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
void cutlass_gemm_caller_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_params) {
|
||||
static constexpr bool swap_ab = Gemm::swap_ab;
|
||||
using ElementAB = typename Gemm::ElementAB;
|
||||
@@ -240,15 +237,15 @@ void cutlass_gemm_caller_sm100_fp8(torch::stable::Tensor& out,
|
||||
|
||||
template <typename InType, typename OutType, bool EnableBias,
|
||||
typename... EpilogueArgs>
|
||||
inline void cutlass_gemm_sm100_fp8_dispatch(
|
||||
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales, EpilogueArgs&&... args) {
|
||||
inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
STD_TORCH_CHECK(a.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
STD_TORCH_CHECK(b.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
using Cutlass3xGemmDefault =
|
||||
typename sm100_fp8_config_default<InType, OutType,
|
||||
@@ -295,24 +292,22 @@ inline void cutlass_gemm_sm100_fp8_dispatch(
|
||||
}
|
||||
|
||||
template <bool EnableBias, typename... EpilogueArgs>
|
||||
void cutlass_scaled_mm_sm100_fp8_epilogue(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
void cutlass_scaled_mm_sm100_fp8_epilogue(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
EpilogueArgs&&... epilogue_args) {
|
||||
STD_TORCH_CHECK(a.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
STD_TORCH_CHECK(b.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_gemm_sm100_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::bfloat16_t, EnableBias>(
|
||||
out, a, b, a_scales, b_scales,
|
||||
std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
return cutlass_gemm_sm100_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::half_t, EnableBias>(
|
||||
out, a, b, a_scales, b_scales,
|
||||
24
csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8.cu
Normal file
24
csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8.cu
Normal file
@@ -0,0 +1,24 @@
|
||||
#include "scaled_mm_kernels.hpp"
|
||||
#include "scaled_mm_sm120_fp8_dispatch.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
void cutlass_scaled_mm_sm120_fp8(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
if (bias) {
|
||||
TORCH_CHECK(bias->dtype() == out.dtype(),
|
||||
"currently bias dtype must match output dtype ", out.dtype());
|
||||
return cutlass_scaled_mm_sm120_fp8_epilogue<c3x::ScaledEpilogueBias>(
|
||||
out, a, b, a_scales, b_scales, *bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm120_fp8_epilogue<c3x::ScaledEpilogue>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@@ -1,7 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/headeronly/util/shim_utils.h>
|
||||
|
||||
#include "scaled_mm.cuh"
|
||||
#include "cutlass_gemm_caller.cuh"
|
||||
|
||||
@@ -140,15 +138,13 @@ struct sm120_fp8_config_M16 {
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
inline void cutlass_gemm_sm120_fp8_dispatch(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
inline void cutlass_gemm_sm120_fp8_dispatch(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
STD_TORCH_CHECK(a.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
STD_TORCH_CHECK(b.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
int M = a.size(0);
|
||||
|
||||
@@ -181,21 +177,19 @@ inline void cutlass_gemm_sm120_fp8_dispatch(torch::stable::Tensor& out,
|
||||
|
||||
template <template <typename, typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_scaled_mm_sm120_fp8_epilogue(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
void cutlass_scaled_mm_sm120_fp8_epilogue(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_args) {
|
||||
STD_TORCH_CHECK(a.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
STD_TORCH_CHECK(b.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_gemm_sm120_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::bfloat16_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
return cutlass_gemm_sm120_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::half_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
23
csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8.cu
Normal file
23
csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8.cu
Normal file
@@ -0,0 +1,23 @@
|
||||
#include "scaled_mm_kernels.hpp"
|
||||
#include "scaled_mm_sm90_fp8_dispatch.cuh"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
void cutlass_scaled_mm_sm90_fp8(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
if (bias) {
|
||||
TORCH_CHECK(bias->dtype() == out.dtype(),
|
||||
"currently bias dtype must match output dtype ", out.dtype());
|
||||
return cutlass_scaled_mm_sm90_fp8_epilogue<true>(out, a, b, a_scales,
|
||||
b_scales, *bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm90_fp8_epilogue<false>(out, a, b, a_scales,
|
||||
b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@@ -1,7 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/headeronly/util/shim_utils.h>
|
||||
|
||||
#include "scaled_mm.cuh"
|
||||
#include "cutlass_gemm_caller.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
@@ -237,9 +235,8 @@ struct sm90_fp8_config_M16_N8192 {
|
||||
};
|
||||
|
||||
template <typename Gemm, typename... EpilogueArgs>
|
||||
void cutlass_gemm_caller_sm90_fp8(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
void cutlass_gemm_caller_sm90_fp8(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_params) {
|
||||
static constexpr bool swap_ab = Gemm::swap_ab;
|
||||
using ElementAB = typename Gemm::ElementAB;
|
||||
@@ -283,15 +280,15 @@ void cutlass_gemm_caller_sm90_fp8(torch::stable::Tensor& out,
|
||||
|
||||
template <typename InType, typename OutType, bool EnableBias,
|
||||
typename... EpilogueArgs>
|
||||
inline void cutlass_gemm_sm90_fp8_dispatch(
|
||||
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales, EpilogueArgs&&... args) {
|
||||
inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
STD_TORCH_CHECK(a.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
STD_TORCH_CHECK(b.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
using Cutlass3xGemmDefault =
|
||||
typename sm90_fp8_config_default<InType, OutType,
|
||||
@@ -350,24 +347,22 @@ inline void cutlass_gemm_sm90_fp8_dispatch(
|
||||
}
|
||||
|
||||
template <bool EnableBias, typename... EpilogueArgs>
|
||||
void cutlass_scaled_mm_sm90_fp8_epilogue(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
void cutlass_scaled_mm_sm90_fp8_epilogue(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
EpilogueArgs&&... epilogue_args) {
|
||||
STD_TORCH_CHECK(a.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
STD_TORCH_CHECK(b.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::bfloat16_t, EnableBias>(
|
||||
out, a, b, a_scales, b_scales,
|
||||
std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::half_t, EnableBias>(
|
||||
out, a, b, a_scales, b_scales,
|
||||
24
csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_int8.cu
Normal file
24
csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_int8.cu
Normal file
@@ -0,0 +1,24 @@
|
||||
#include "scaled_mm_kernels.hpp"
|
||||
#include "scaled_mm_sm90_int8_dispatch.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
void cutlass_scaled_mm_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
if (bias) {
|
||||
TORCH_CHECK(bias->dtype() == out.dtype(),
|
||||
"currently bias dtype must match output dtype ", out.dtype());
|
||||
return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogueBias>(
|
||||
out, a, b, a_scales, b_scales, *bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogue>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@@ -1,7 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/headeronly/util/shim_utils.h>
|
||||
|
||||
#include "scaled_mm.cuh"
|
||||
#include "cutlass_gemm_caller.cuh"
|
||||
|
||||
@@ -89,13 +87,13 @@ struct sm90_int8_config_M32_NSmall {
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
inline void cutlass_gemm_sm90_int8_dispatch(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
inline void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||
|
||||
using Cutlass3xGemmDefault =
|
||||
typename sm90_int8_config_default<InType, OutType,
|
||||
@@ -144,19 +142,19 @@ inline void cutlass_gemm_sm90_int8_dispatch(torch::stable::Tensor& out,
|
||||
|
||||
template <template <typename, typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_scaled_mm_sm90_int8_epilogue(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
void cutlass_scaled_mm_sm90_int8_epilogue(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_args) {
|
||||
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||
|
||||
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
|
||||
Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
@@ -1,10 +1,10 @@
|
||||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include <torch/headeronly/core/ScalarType.h>
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
#include <torch/all.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
|
||||
#include "core/scalar_type.hpp"
|
||||
#include "cutlass/bfloat16.h"
|
||||
#include "cutlass/float8.h"
|
||||
|
||||
@@ -31,7 +31,7 @@ __global__ void get_group_gemm_starts(
|
||||
}
|
||||
|
||||
#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE) \
|
||||
else if (out_tensors.scalar_type() == TENSOR_C_TYPE) { \
|
||||
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
|
||||
get_group_gemm_starts<cutlass::float_e4m3_t, C_TYPE, float> \
|
||||
<<<1, num_experts, 0, stream>>>( \
|
||||
static_cast<int64_t*>(expert_offsets.data_ptr()), \
|
||||
@@ -51,39 +51,32 @@ __global__ void get_group_gemm_starts(
|
||||
namespace {
|
||||
|
||||
void run_get_group_gemm_starts(
|
||||
torch::stable::Tensor const& expert_offsets, torch::stable::Tensor& a_ptrs,
|
||||
torch::stable::Tensor& b_ptrs, torch::stable::Tensor& out_ptrs,
|
||||
torch::stable::Tensor& a_scales_ptrs, torch::stable::Tensor& b_scales_ptrs,
|
||||
torch::stable::Tensor const& a_tensors,
|
||||
torch::stable::Tensor const& b_tensors, torch::stable::Tensor& out_tensors,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales) {
|
||||
STD_TORCH_CHECK(a_tensors.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
STD_TORCH_CHECK(b_tensors.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
STD_TORCH_CHECK(a_scales.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float);
|
||||
STD_TORCH_CHECK(b_scales.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float);
|
||||
torch::Tensor const& expert_offsets, torch::Tensor& a_ptrs,
|
||||
torch::Tensor& b_ptrs, torch::Tensor& out_ptrs,
|
||||
torch::Tensor& a_scales_ptrs, torch::Tensor& b_scales_ptrs,
|
||||
torch::Tensor const& a_tensors, torch::Tensor const& b_tensors,
|
||||
torch::Tensor& out_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
// expect int64_t to avoid overflow during offset calculations
|
||||
STD_TORCH_CHECK(expert_offsets.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Long);
|
||||
TORCH_CHECK(expert_offsets.dtype() == torch::kInt64);
|
||||
|
||||
int num_experts = static_cast<int>(expert_offsets.size(0));
|
||||
bool per_act_token = a_scales.numel() != 1;
|
||||
bool per_out_ch = b_scales.numel() != num_experts;
|
||||
|
||||
auto stream = get_current_cuda_stream(a_tensors.get_device_index());
|
||||
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
|
||||
|
||||
if (false) {
|
||||
}
|
||||
__CALL_GET_STARTS_KERNEL(torch::headeronly::ScalarType::BFloat16,
|
||||
cutlass::bfloat16_t)
|
||||
__CALL_GET_STARTS_KERNEL(torch::headeronly::ScalarType::Half, half)
|
||||
__CALL_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t)
|
||||
__CALL_GET_STARTS_KERNEL(torch::kFloat16, half)
|
||||
else {
|
||||
STD_TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
|
||||
TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace
|
||||
@@ -6,7 +6,6 @@
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
|
||||
#include <torch/csrc/stable/ops.h>
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
#include "get_group_starts.cuh"
|
||||
@@ -85,17 +84,13 @@ struct cutlass_3x_group_gemm {
|
||||
};
|
||||
|
||||
template <typename Gemm>
|
||||
void cutlass_group_gemm_caller(torch::stable::Tensor& out_tensors,
|
||||
torch::stable::Tensor const& a_tensors,
|
||||
torch::stable::Tensor const& b_tensors,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
torch::stable::Tensor const& expert_offsets,
|
||||
torch::stable::Tensor const& problem_sizes,
|
||||
torch::stable::Tensor const& a_strides,
|
||||
torch::stable::Tensor const& b_strides,
|
||||
torch::stable::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch) {
|
||||
void cutlass_group_gemm_caller(
|
||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch) {
|
||||
static constexpr bool swap_ab = Gemm::swap_ab;
|
||||
|
||||
using ElementAB = typename Gemm::ElementAB;
|
||||
@@ -103,20 +98,16 @@ void cutlass_group_gemm_caller(torch::stable::Tensor& out_tensors,
|
||||
|
||||
int num_experts = static_cast<int>(expert_offsets.size(0));
|
||||
|
||||
auto stream = get_current_cuda_stream(a_tensors.get_device_index());
|
||||
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
|
||||
|
||||
auto device = a_tensors.device();
|
||||
auto options_int =
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a_tensors.device());
|
||||
|
||||
torch::stable::Tensor a_ptrs = torch::stable::empty(
|
||||
{num_experts}, torch::headeronly::ScalarType::Long, std::nullopt, device);
|
||||
torch::stable::Tensor b_ptrs = torch::stable::empty(
|
||||
{num_experts}, torch::headeronly::ScalarType::Long, std::nullopt, device);
|
||||
torch::stable::Tensor out_ptrs = torch::stable::empty(
|
||||
{num_experts}, torch::headeronly::ScalarType::Long, std::nullopt, device);
|
||||
torch::stable::Tensor a_scales_ptrs = torch::stable::empty(
|
||||
{num_experts}, torch::headeronly::ScalarType::Long, std::nullopt, device);
|
||||
torch::stable::Tensor b_scales_ptrs = torch::stable::empty(
|
||||
{num_experts}, torch::headeronly::ScalarType::Long, std::nullopt, device);
|
||||
torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
|
||||
|
||||
run_get_group_gemm_starts(expert_offsets, a_ptrs, b_ptrs, out_ptrs,
|
||||
a_scales_ptrs, b_scales_ptrs, a_tensors, b_tensors,
|
||||
@@ -165,7 +156,7 @@ void cutlass_group_gemm_caller(torch::stable::Tensor& out_tensors,
|
||||
static_cast<ElementD**>(out_ptrs.data_ptr()),
|
||||
static_cast<StrideC*>(c_strides.data_ptr())};
|
||||
|
||||
int device_id = a_tensors.get_device_index();
|
||||
int device_id = a_tensors.device().index();
|
||||
static const cutlass::KernelHardwareInfo hw_info{
|
||||
device_id, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
|
||||
device_id)};
|
||||
@@ -179,9 +170,9 @@ void cutlass_group_gemm_caller(torch::stable::Tensor& out_tensors,
|
||||
CUTLASS_CHECK(gemm_op.can_implement(args));
|
||||
|
||||
size_t workspace_size = gemm_op.get_workspace_size(args);
|
||||
auto workspace =
|
||||
torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
|
||||
std::nullopt, device);
|
||||
auto const workspace_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(a_tensors.device());
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
|
||||
cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
|
||||
CUTLASS_CHECK(status);
|
||||
@@ -1,8 +1,7 @@
|
||||
#include <cudaTypedefs.h>
|
||||
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include <torch/headeronly/core/ScalarType.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "grouped_mm_c3x.cuh"
|
||||
@@ -63,27 +62,21 @@ struct sm100_fp8_config_N8192 {
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType>
|
||||
void run_cutlass_moe_mm_sm100(torch::stable::Tensor& out_tensors,
|
||||
torch::stable::Tensor const& a_tensors,
|
||||
torch::stable::Tensor const& b_tensors,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
torch::stable::Tensor const& expert_offsets,
|
||||
torch::stable::Tensor const& problem_sizes,
|
||||
torch::stable::Tensor const& a_strides,
|
||||
torch::stable::Tensor const& b_strides,
|
||||
torch::stable::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch) {
|
||||
STD_TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
|
||||
STD_TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
|
||||
STD_TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");
|
||||
void run_cutlass_moe_mm_sm100(
|
||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch) {
|
||||
TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
|
||||
TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
|
||||
TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");
|
||||
|
||||
STD_TORCH_CHECK(
|
||||
a_tensors.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn,
|
||||
"A tensors must be of type float8_e4m3fn.");
|
||||
STD_TORCH_CHECK(
|
||||
b_tensors.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn,
|
||||
"B tensors must be of type float8_e4m3fn.");
|
||||
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn,
|
||||
"A tensors must be of type float8_e4m3fn.");
|
||||
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn,
|
||||
"B tensors must be of type float8_e4m3fn.");
|
||||
|
||||
using Cutlass3xGemmDefault = typename sm100_fp8_config_default<
|
||||
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
|
||||
@@ -114,18 +107,14 @@ void run_cutlass_moe_mm_sm100(torch::stable::Tensor& out_tensors,
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void dispatch_moe_mm_sm100(torch::stable::Tensor& out_tensors,
|
||||
torch::stable::Tensor const& a_tensors,
|
||||
torch::stable::Tensor const& b_tensors,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
torch::stable::Tensor const& expert_offsets,
|
||||
torch::stable::Tensor const& problem_sizes,
|
||||
torch::stable::Tensor const& a_strides,
|
||||
torch::stable::Tensor const& b_strides,
|
||||
torch::stable::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch) {
|
||||
if (out_tensors.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
|
||||
void dispatch_moe_mm_sm100(
|
||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch) {
|
||||
if (out_tensors.dtype() == torch::kBFloat16) {
|
||||
run_cutlass_moe_mm_sm100<cutlass::float_e4m3_t, cutlass::bfloat16_t>(
|
||||
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
|
||||
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
|
||||
@@ -138,17 +127,13 @@ void dispatch_moe_mm_sm100(torch::stable::Tensor& out_tensors,
|
||||
}
|
||||
}
|
||||
|
||||
void cutlass_moe_mm_sm100(torch::stable::Tensor& out_tensors,
|
||||
torch::stable::Tensor const& a_tensors,
|
||||
torch::stable::Tensor const& b_tensors,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
torch::stable::Tensor const& expert_offsets,
|
||||
torch::stable::Tensor const& problem_sizes,
|
||||
torch::stable::Tensor const& a_strides,
|
||||
torch::stable::Tensor const& b_strides,
|
||||
torch::stable::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch) {
|
||||
void cutlass_moe_mm_sm100(
|
||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch) {
|
||||
dispatch_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
|
||||
expert_offsets, problem_sizes, a_strides, b_strides,
|
||||
c_strides, per_act_token, per_out_ch);
|
||||
@@ -1,8 +1,7 @@
|
||||
#include <cudaTypedefs.h>
|
||||
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include <torch/headeronly/core/ScalarType.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "grouped_mm_c3x.cuh"
|
||||
@@ -104,27 +103,21 @@ struct sm90_fp8_config_N8192 {
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType>
|
||||
void run_cutlass_moe_mm_sm90(torch::stable::Tensor& out_tensors,
|
||||
torch::stable::Tensor const& a_tensors,
|
||||
torch::stable::Tensor const& b_tensors,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
torch::stable::Tensor const& expert_offsets,
|
||||
torch::stable::Tensor const& problem_sizes,
|
||||
torch::stable::Tensor const& a_strides,
|
||||
torch::stable::Tensor const& b_strides,
|
||||
torch::stable::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch) {
|
||||
STD_TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
|
||||
STD_TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
|
||||
STD_TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");
|
||||
void run_cutlass_moe_mm_sm90(
|
||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch) {
|
||||
TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
|
||||
TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
|
||||
TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");
|
||||
|
||||
STD_TORCH_CHECK(
|
||||
a_tensors.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn,
|
||||
"A tensors must be of type float8_e4m3fn.");
|
||||
STD_TORCH_CHECK(
|
||||
b_tensors.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn,
|
||||
"B tensors must be of type float8_e4m3fn.");
|
||||
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn,
|
||||
"A tensors must be of type float8_e4m3fn.");
|
||||
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn,
|
||||
"B tensors must be of type float8_e4m3fn.");
|
||||
|
||||
using Cutlass3xGemmN8192 = typename sm90_fp8_config_N8192<
|
||||
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
|
||||
@@ -170,18 +163,14 @@ void run_cutlass_moe_mm_sm90(torch::stable::Tensor& out_tensors,
|
||||
}
|
||||
}
|
||||
|
||||
void dispatch_moe_mm_sm90(torch::stable::Tensor& out_tensors,
|
||||
torch::stable::Tensor const& a_tensors,
|
||||
torch::stable::Tensor const& b_tensors,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
torch::stable::Tensor const& expert_offsets,
|
||||
torch::stable::Tensor const& problem_sizes,
|
||||
torch::stable::Tensor const& a_strides,
|
||||
torch::stable::Tensor const& b_strides,
|
||||
torch::stable::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch) {
|
||||
if (out_tensors.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
|
||||
void dispatch_moe_mm_sm90(
|
||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch) {
|
||||
if (out_tensors.dtype() == torch::kBFloat16) {
|
||||
run_cutlass_moe_mm_sm90<cutlass::float_e4m3_t, cutlass::bfloat16_t>(
|
||||
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
|
||||
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
|
||||
@@ -196,17 +185,13 @@ void dispatch_moe_mm_sm90(torch::stable::Tensor& out_tensors,
|
||||
|
||||
} // namespace
|
||||
|
||||
void cutlass_moe_mm_sm90(torch::stable::Tensor& out_tensors,
|
||||
torch::stable::Tensor const& a_tensors,
|
||||
torch::stable::Tensor const& b_tensors,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
torch::stable::Tensor const& expert_offsets,
|
||||
torch::stable::Tensor const& problem_sizes,
|
||||
torch::stable::Tensor const& a_strides,
|
||||
torch::stable::Tensor const& b_strides,
|
||||
torch::stable::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch) {
|
||||
void cutlass_moe_mm_sm90(
|
||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch) {
|
||||
dispatch_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
|
||||
expert_offsets, problem_sizes, a_strides, b_strides,
|
||||
c_strides, per_act_token, per_out_ch);
|
||||
@@ -1,11 +1,9 @@
|
||||
#include <cudaTypedefs.h>
|
||||
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include <torch/csrc/stable/ops.h>
|
||||
#include <torch/headeronly/core/ScalarType.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "libtorch_stable/dispatch_utils.h"
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
#include <iostream>
|
||||
|
||||
@@ -112,22 +110,19 @@ __global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids,
|
||||
}
|
||||
|
||||
namespace {
|
||||
inline void launch_compute_problem_sizes(const torch::stable::Tensor& topk_ids,
|
||||
torch::stable::Tensor& problem_sizes1,
|
||||
torch::stable::Tensor& problem_sizes2,
|
||||
torch::stable::Tensor& atomic_buffer,
|
||||
int64_t num_experts, int64_t n,
|
||||
int64_t k, cudaStream_t stream,
|
||||
const bool swap_ab,
|
||||
const bool is_gated) {
|
||||
inline void launch_compute_problem_sizes(
|
||||
const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2, torch::Tensor& atomic_buffer,
|
||||
int64_t num_experts, int64_t n, int64_t k, cudaStream_t stream,
|
||||
const bool swap_ab, const bool is_gated) {
|
||||
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
|
||||
|
||||
auto const* topk_ptr = topk_ids.const_data_ptr<int32_t>();
|
||||
auto* ps1_ptr = problem_sizes1.mutable_data_ptr<int32_t>();
|
||||
auto* ps2_ptr = problem_sizes2.mutable_data_ptr<int32_t>();
|
||||
auto* atomic_ptr = atomic_buffer.mutable_data_ptr<int32_t>();
|
||||
auto const* topk_ptr = topk_ids.data_ptr<int32_t>();
|
||||
auto* ps1_ptr = problem_sizes1.data_ptr<int32_t>();
|
||||
auto* ps2_ptr = problem_sizes2.data_ptr<int32_t>();
|
||||
auto* atomic_ptr = atomic_buffer.data_ptr<int32_t>();
|
||||
|
||||
VLLM_STABLE_DISPATCH_BOOL(swap_ab, SwapAB, [&] {
|
||||
VLLM_DISPATCH_BOOL(swap_ab, SwapAB, [&] {
|
||||
compute_problem_sizes<SwapAB><<<num_experts, num_threads, 0, stream>>>(
|
||||
topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr,
|
||||
static_cast<int>(topk_ids.numel()), static_cast<int>(n),
|
||||
@@ -176,53 +171,46 @@ __global__ void compute_problem_sizes_from_expert_offsets(
|
||||
}
|
||||
|
||||
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
|
||||
const torch::stable::Tensor& expert_first_token_offset,
|
||||
torch::stable::Tensor& problem_sizes1,
|
||||
torch::stable::Tensor& problem_sizes2, const int64_t n, const int64_t k,
|
||||
const bool swap_ab) {
|
||||
STD_TORCH_CHECK(expert_first_token_offset.is_cuda(),
|
||||
"expert_first_token_offset must be a CUDA tensor");
|
||||
STD_TORCH_CHECK(expert_first_token_offset.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Long,
|
||||
"expert_first_token_offset must be int64");
|
||||
const torch::Tensor& expert_first_token_offset,
|
||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||
const int64_t n, const int64_t k, const bool swap_ab) {
|
||||
TORCH_CHECK(expert_first_token_offset.is_cuda(),
|
||||
"expert_first_token_offset must be a CUDA tensor");
|
||||
TORCH_CHECK(expert_first_token_offset.dtype() == torch::kInt64,
|
||||
"expert_first_token_offset must be int64");
|
||||
|
||||
STD_TORCH_CHECK(problem_sizes1.is_cuda() && problem_sizes2.is_cuda(),
|
||||
"problem_sizes must be CUDA tensors");
|
||||
STD_TORCH_CHECK(
|
||||
problem_sizes1.scalar_type() == torch::headeronly::ScalarType::Int &&
|
||||
problem_sizes2.scalar_type() == torch::headeronly::ScalarType::Int,
|
||||
"problem_sizes must be int32");
|
||||
STD_TORCH_CHECK(
|
||||
problem_sizes1.is_contiguous() && problem_sizes2.is_contiguous(),
|
||||
"problem_sizes must be contiguous");
|
||||
STD_TORCH_CHECK(problem_sizes1.dim() == 2 && problem_sizes2.dim() == 2,
|
||||
"problem_sizes must be 2D tensors");
|
||||
STD_TORCH_CHECK(problem_sizes1.size(1) == 3 && problem_sizes2.size(1) == 3,
|
||||
"problem_sizes second dim must be 3");
|
||||
STD_TORCH_CHECK(problem_sizes1.size(0) == problem_sizes2.size(0) &&
|
||||
problem_sizes1.size(1) == problem_sizes2.size(1),
|
||||
"problem_sizes1 and problem_sizes2 must have same shape");
|
||||
TORCH_CHECK(problem_sizes1.is_cuda() && problem_sizes2.is_cuda(),
|
||||
"problem_sizes must be CUDA tensors");
|
||||
TORCH_CHECK(problem_sizes1.dtype() == torch::kInt32 &&
|
||||
problem_sizes2.dtype() == torch::kInt32,
|
||||
"problem_sizes must be int32");
|
||||
TORCH_CHECK(problem_sizes1.is_contiguous() && problem_sizes2.is_contiguous(),
|
||||
"problem_sizes must be contiguous");
|
||||
TORCH_CHECK(problem_sizes1.dim() == 2 && problem_sizes2.dim() == 2,
|
||||
"problem_sizes must be 2D tensors");
|
||||
TORCH_CHECK(problem_sizes1.size(1) == 3 && problem_sizes2.size(1) == 3,
|
||||
"problem_sizes second dim must be 3");
|
||||
TORCH_CHECK(problem_sizes1.sizes() == problem_sizes2.sizes(),
|
||||
"problem_sizes1 and problem_sizes2 must have same shape");
|
||||
|
||||
int64_t const num_experts64 = problem_sizes1.size(0);
|
||||
STD_TORCH_CHECK(
|
||||
expert_first_token_offset.numel() == num_experts64 + 1,
|
||||
"expert_first_token_offset must have num_experts + 1 elements");
|
||||
STD_TORCH_CHECK(num_experts64 <= INT32_MAX, "num_experts must fit in int32");
|
||||
STD_TORCH_CHECK(n <= INT32_MAX && k <= INT32_MAX,
|
||||
"n and k must fit in int32");
|
||||
TORCH_CHECK(expert_first_token_offset.numel() == num_experts64 + 1,
|
||||
"expert_first_token_offset must have num_experts + 1 elements");
|
||||
TORCH_CHECK(num_experts64 <= INT32_MAX, "num_experts must fit in int32");
|
||||
TORCH_CHECK(n <= INT32_MAX && k <= INT32_MAX, "n and k must fit in int32");
|
||||
|
||||
int const num_experts = static_cast<int>(num_experts64);
|
||||
auto stream =
|
||||
get_current_cuda_stream(expert_first_token_offset.get_device_index());
|
||||
auto stream = at::cuda::getCurrentCUDAStream(
|
||||
expert_first_token_offset.device().index());
|
||||
|
||||
int const threads = (num_experts < 256) ? num_experts : 256;
|
||||
int const blocks = (num_experts + threads - 1) / threads;
|
||||
|
||||
auto const* offsets_ptr = expert_first_token_offset.const_data_ptr<int64_t>();
|
||||
auto* ps1_ptr = problem_sizes1.mutable_data_ptr<int32_t>();
|
||||
auto* ps2_ptr = problem_sizes2.mutable_data_ptr<int32_t>();
|
||||
auto const* offsets_ptr = expert_first_token_offset.data_ptr<int64_t>();
|
||||
auto* ps1_ptr = problem_sizes1.data_ptr<int32_t>();
|
||||
auto* ps2_ptr = problem_sizes2.data_ptr<int32_t>();
|
||||
|
||||
VLLM_STABLE_DISPATCH_BOOL(swap_ab, SwapAB, [&] {
|
||||
VLLM_DISPATCH_BOOL(swap_ab, SwapAB, [&] {
|
||||
compute_problem_sizes_from_expert_offsets<SwapAB>
|
||||
<<<blocks, threads, 0, stream>>>(offsets_ptr, ps1_ptr, ps2_ptr,
|
||||
num_experts, static_cast<int>(n),
|
||||
@@ -231,19 +219,16 @@ void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
|
||||
}
|
||||
|
||||
void get_cutlass_moe_mm_data_caller(
|
||||
const torch::stable::Tensor& topk_ids,
|
||||
torch::stable::Tensor& expert_offsets,
|
||||
torch::stable::Tensor& problem_sizes1,
|
||||
torch::stable::Tensor& problem_sizes2,
|
||||
torch::stable::Tensor& input_permutation,
|
||||
torch::stable::Tensor& output_permutation, const int64_t num_experts,
|
||||
const int64_t n, const int64_t k,
|
||||
const std::optional<torch::stable::Tensor>& blockscale_offsets,
|
||||
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
|
||||
const int64_t num_experts, const int64_t n, const int64_t k,
|
||||
const std::optional<torch::Tensor>& blockscale_offsets,
|
||||
const bool is_gated) {
|
||||
auto device = topk_ids.device();
|
||||
auto stream = get_current_cuda_stream(device.index());
|
||||
torch::stable::Tensor atomic_buffer = torch::stable::new_zeros(
|
||||
topk_ids, {num_experts}, torch::headeronly::ScalarType::Int);
|
||||
auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index());
|
||||
auto options_int32 =
|
||||
torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
|
||||
torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32);
|
||||
|
||||
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
|
||||
|
||||
@@ -305,13 +290,11 @@ __global__ void compute_batched_moe_data(
|
||||
}
|
||||
|
||||
void get_cutlass_batched_moe_mm_data_caller(
|
||||
torch::stable::Tensor& expert_offsets,
|
||||
torch::stable::Tensor& problem_sizes1,
|
||||
torch::stable::Tensor& problem_sizes2,
|
||||
const torch::stable::Tensor& expert_num_tokens,
|
||||
torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2, const torch::Tensor& expert_num_tokens,
|
||||
const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
|
||||
const int64_t k) {
|
||||
auto stream = get_current_cuda_stream(expert_offsets.get_device_index());
|
||||
auto stream = at::cuda::getCurrentCUDAStream(expert_offsets.device().index());
|
||||
|
||||
if (num_local_experts * padded_m > SWAP_AB_THRESHOLD) {
|
||||
compute_batched_moe_data<false><<<1, num_local_experts, 0, stream>>>(
|
||||
@@ -328,4 +311,4 @@ void get_cutlass_batched_moe_mm_data_caller(
|
||||
static_cast<const int32_t*>(expert_num_tokens.data_ptr()), padded_m, n,
|
||||
k);
|
||||
}
|
||||
}
|
||||
}
|
||||
199
csrc/quantization/w8a8/cutlass/scaled_mm_c2x.cu
Normal file
199
csrc/quantization/w8a8/cutlass/scaled_mm_c2x.cu
Normal file
@@ -0,0 +1,199 @@
|
||||
#include <stddef.h>
|
||||
#include <torch/all.h>
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "scaled_mm_c2x.cuh"
|
||||
#include "scaled_mm_c2x_sm75_dispatch.cuh"
|
||||
#include "scaled_mm_c2x_sm80_dispatch.cuh"
|
||||
#include "scaled_mm_c2x_sm89_fp8_dispatch.cuh"
|
||||
#include "scaled_mm_c2x_sm89_int8_dispatch.cuh"
|
||||
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp"
|
||||
|
||||
using namespace vllm;
|
||||
|
||||
/*
|
||||
This file defines quantized GEMM operations using the CUTLASS 2.x API, for
|
||||
NVIDIA GPUs with SM versions prior to sm90 (Hopper).
|
||||
*/
|
||||
|
||||
template <template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_scaled_mm_sm75_epilogue(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_args) {
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_gemm_sm75_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
return cutlass_gemm_sm75_dispatch<int8_t, cutlass::half_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
if (bias) {
|
||||
TORCH_CHECK(bias->dtype() == out.dtype(),
|
||||
"currently bias dtype must match output dtype ", out.dtype());
|
||||
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBias>(
|
||||
out, a, b, a_scales, b_scales, *bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogue>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm_azp_sm75(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
std::optional<torch::Tensor> const& azp,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
|
||||
if (azp) {
|
||||
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
|
||||
out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBiasAzp>(
|
||||
out, a, b, a_scales, b_scales, azp_adj, bias);
|
||||
}
|
||||
}
|
||||
|
||||
template <template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_scaled_mm_sm80_epilogue(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_args) {
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
return cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
if (bias) {
|
||||
TORCH_CHECK(bias->dtype() == out.dtype(),
|
||||
"currently bias dtype must match output dtype ", out.dtype());
|
||||
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBias>(
|
||||
out, a, b, a_scales, b_scales, *bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogue>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm_azp_sm80(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
std::optional<torch::Tensor> const& azp,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
|
||||
if (azp) {
|
||||
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
|
||||
out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBiasAzp>(
|
||||
out, a, b, a_scales, b_scales, azp_adj, bias);
|
||||
}
|
||||
}
|
||||
|
||||
template <template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_args) {
|
||||
if (a.dtype() == torch::kInt8) {
|
||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::bfloat16_t,
|
||||
Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
assert(out.dtype() == torch::kFloat16);
|
||||
return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::bfloat16_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::half_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
if (bias) {
|
||||
TORCH_CHECK(bias->dtype() == out.dtype(),
|
||||
"currently bias dtype must match output dtype ", out.dtype());
|
||||
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBias>(
|
||||
out, a, b, a_scales, b_scales, *bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogue>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm_azp_sm89(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
std::optional<torch::Tensor> const& azp,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
|
||||
if (azp) {
|
||||
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
|
||||
out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBiasAzp>(
|
||||
out, a, b, a_scales, b_scales, azp_adj, bias);
|
||||
}
|
||||
}
|
||||
@@ -1,9 +1,8 @@
|
||||
#pragma once
|
||||
#include <stddef.h>
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include <torch/csrc/stable/ops.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
// clang-format will break include orders
|
||||
// clang-format off
|
||||
@@ -96,9 +95,8 @@ struct cutlass_2x_gemm {
|
||||
};
|
||||
|
||||
template <typename Gemm, typename... EpilogueArgs>
|
||||
inline void cutlass_gemm_caller(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
inline void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_params) {
|
||||
using ElementAB = typename Gemm::ElementAB;
|
||||
using ElementD = typename Gemm::ElementD;
|
||||
@@ -151,12 +149,11 @@ inline void cutlass_gemm_caller(torch::stable::Tensor& out,
|
||||
// Launch the CUTLASS GEMM kernel.
|
||||
typename Gemm::Op gemm_op;
|
||||
size_t workspace_size = gemm_op.get_workspace_size(args);
|
||||
auto device = a.device();
|
||||
auto workspace =
|
||||
torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
|
||||
std::nullopt, device);
|
||||
auto const workspace_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
|
||||
auto stream = get_current_cuda_stream(device.index());
|
||||
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
|
||||
|
||||
CUTLASS_CHECK(gemm_op.can_implement(args));
|
||||
cutlass::Status status = gemm_op(args, workspace.data_ptr(), stream);
|
||||
@@ -164,9 +161,9 @@ inline void cutlass_gemm_caller(torch::stable::Tensor& out,
|
||||
}
|
||||
|
||||
template <typename Gemm, typename FallbackGemm, typename... EpilogueArgs>
|
||||
inline void fallback_cutlass_gemm_caller(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
inline void fallback_cutlass_gemm_caller(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... args) {
|
||||
// In some cases, the GPU isn't able to accommodate the
|
||||
// shared memory requirements of the Gemm. In such cases, use
|
||||
@@ -183,8 +180,8 @@ inline void fallback_cutlass_gemm_caller(torch::stable::Tensor& out,
|
||||
return cutlass_gemm_caller<Gemm>(out, a, b,
|
||||
std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
STD_TORCH_CHECK(fallback_gemm_shared_mem_size <=
|
||||
max_shared_mem_per_block_opt_in);
|
||||
TORCH_CHECK(fallback_gemm_shared_mem_size <=
|
||||
max_shared_mem_per_block_opt_in);
|
||||
return cutlass_gemm_caller<FallbackGemm>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
@@ -1,7 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/headeronly/util/shim_utils.h>
|
||||
|
||||
#include "scaled_mm_c2x.cuh"
|
||||
|
||||
/**
|
||||
@@ -72,13 +70,13 @@ struct sm75_config_M32 {
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
inline void cutlass_gemm_sm75_dispatch(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
inline void cutlass_gemm_sm75_dispatch(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||
|
||||
using Cutlass2xGemmDefault =
|
||||
typename sm75_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
|
||||
@@ -1,7 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/headeronly/util/shim_utils.h>
|
||||
|
||||
#include "scaled_mm_c2x.cuh"
|
||||
|
||||
/**
|
||||
@@ -74,13 +72,13 @@ struct sm80_config_M16 {
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
inline void cutlass_gemm_sm80_dispatch(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
inline void cutlass_gemm_sm80_dispatch(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||
|
||||
using Cutlass2xGemmDefault =
|
||||
typename sm80_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
|
||||
@@ -1,7 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/headeronly/util/shim_utils.h>
|
||||
|
||||
#include "scaled_mm_c2x.cuh"
|
||||
#include "cutlass/float8.h"
|
||||
|
||||
@@ -36,12 +34,10 @@ struct sm89_fp8_config_default {
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
STD_TORCH_CHECK(a.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_fp8_fallback_gemm<InType, OutType,
|
||||
@@ -88,12 +84,10 @@ struct sm89_fp8_config_M256 {
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
STD_TORCH_CHECK(a.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_fp8_fallback_gemm<InType, OutType,
|
||||
@@ -131,12 +125,10 @@ struct sm89_fp8_config_M128 {
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
STD_TORCH_CHECK(a.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_fp8_fallback_gemm<InType, OutType,
|
||||
@@ -181,12 +173,10 @@ struct sm89_fp8_config_M64 {
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
STD_TORCH_CHECK(a.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_fp8_fallback_gemm<InType, OutType,
|
||||
@@ -237,12 +227,10 @@ struct sm89_fp8_config_M32 {
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
STD_TORCH_CHECK(a.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_fp8_fallback_gemm<InType, OutType,
|
||||
@@ -292,12 +280,10 @@ struct sm89_fp8_config_M16 {
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
STD_TORCH_CHECK(a.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_fp8_fallback_gemm<InType, OutType,
|
||||
@@ -340,15 +326,13 @@ struct sm89_fp8_config_M16 {
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
inline void cutlass_gemm_sm89_fp8_dispatch(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
inline void cutlass_gemm_sm89_fp8_dispatch(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
STD_TORCH_CHECK(a.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
STD_TORCH_CHECK(b.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
uint32_t const m = a.size(0);
|
||||
uint32_t const mp2 =
|
||||
@@ -1,7 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/headeronly/util/shim_utils.h>
|
||||
|
||||
#include "scaled_mm_c2x.cuh"
|
||||
|
||||
/**
|
||||
@@ -34,11 +32,10 @@ struct sm89_int8_config_default {
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_int8_fallback_gemm<InType, OutType,
|
||||
@@ -91,11 +88,10 @@ struct sm89_int8_config_M256 {
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_int8_fallback_gemm<InType, OutType,
|
||||
@@ -147,11 +143,10 @@ struct sm89_int8_config_M128 {
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_int8_fallback_gemm<InType, OutType,
|
||||
@@ -198,11 +193,10 @@ struct sm89_int8_config_M64 {
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_int8_fallback_gemm<InType, OutType,
|
||||
@@ -240,11 +234,10 @@ struct sm89_int8_config_M32 {
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_int8_fallback_gemm<InType, OutType,
|
||||
@@ -283,11 +276,10 @@ struct sm89_int8_config_M16 {
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_int8_fallback_gemm<InType, OutType,
|
||||
@@ -319,13 +311,13 @@ struct sm89_int8_config_M16 {
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
inline void cutlass_gemm_sm89_int8_dispatch(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
inline void cutlass_gemm_sm89_int8_dispatch(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||
|
||||
uint32_t const m = a.size(0);
|
||||
uint32_t const mp2 =
|
||||
@@ -8,12 +8,11 @@
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
|
||||
|
||||
void cutlass_scaled_mm_sm100(torch::stable::Tensor& c,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
std::optional<torch::stable::Tensor> const& bias) {
|
||||
void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
|
||||
vllm::cutlass_scaled_mm_sm100_fp8,
|
||||
nullptr, // int8 not supported on SM100
|
||||
@@ -8,12 +8,11 @@
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
|
||||
|
||||
void cutlass_scaled_mm_sm120(torch::stable::Tensor& c,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
std::optional<torch::stable::Tensor> const& bias) {
|
||||
void cutlass_scaled_mm_sm120(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
|
||||
vllm::cutlass_scaled_mm_sm120_fp8,
|
||||
nullptr, // int8 not supported on SM120
|
||||
36
csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm90.cu
Normal file
36
csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm90.cu
Normal file
@@ -0,0 +1,36 @@
|
||||
#include "c3x/scaled_mm_helper.hpp"
|
||||
#include "c3x/scaled_mm_kernels.hpp"
|
||||
|
||||
/*
|
||||
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
|
||||
NVIDIA GPUs with sm90a (Hopper).
|
||||
*/
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
|
||||
|
||||
void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
|
||||
vllm::cutlass_scaled_mm_sm90_fp8,
|
||||
vllm::cutlass_scaled_mm_sm90_int8,
|
||||
vllm::cutlass_scaled_mm_blockwise_sm90_fp8);
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
std::optional<torch::Tensor> const& azp,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
|
||||
vllm::cutlass_scaled_mm_azp_sm90_int8(out, a, b, a_scales, b_scales, azp_adj,
|
||||
azp, bias);
|
||||
}
|
||||
|
||||
#endif
|
||||
420
csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu
Normal file
420
csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu
Normal file
@@ -0,0 +1,420 @@
|
||||
#include <cudaTypedefs.h>
|
||||
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
|
||||
void cutlass_scaled_mm_sm75(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_sm80(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
|
||||
void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
#endif
|
||||
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
|
||||
void cutlass_moe_mm_sm90(
|
||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch);
|
||||
|
||||
#endif
|
||||
|
||||
#if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100
|
||||
void cutlass_moe_mm_sm100(
|
||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch);
|
||||
#endif
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
|
||||
void cutlass_scaled_mm_sm120(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
#endif
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
|
||||
void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
#endif
|
||||
|
||||
#if (defined(ENABLE_CUTLASS_MOE_SM90) && ENABLE_CUTLASS_MOE_SM90) || \
|
||||
(defined(ENABLE_CUTLASS_MOE_SM100) && ENABLE_CUTLASS_MOE_SM100) || \
|
||||
(defined(ENABLE_CUTLASS_MOE_SM120) && ENABLE_CUTLASS_MOE_SM120)
|
||||
void get_cutlass_moe_mm_data_caller(
|
||||
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
|
||||
const int64_t num_experts, const int64_t n, const int64_t k,
|
||||
const std::optional<torch::Tensor>& blockscale_offsets,
|
||||
const bool is_gated);
|
||||
|
||||
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
|
||||
const torch::Tensor& expert_first_token_offset,
|
||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||
const int64_t n, const int64_t k, const bool swap_ab);
|
||||
|
||||
void get_cutlass_batched_moe_mm_data_caller(
|
||||
torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2, const torch::Tensor& expert_num_tokens,
|
||||
const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
|
||||
const int64_t k);
|
||||
#endif
|
||||
|
||||
void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
std::optional<torch::Tensor> const& azp,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_azp_sm80(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
std::optional<torch::Tensor> const& azp,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_azp_sm89(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
std::optional<torch::Tensor> const& azp,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
|
||||
void cutlass_scaled_mm_azp_sm90(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
std::optional<torch::Tensor> const& azp,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
#endif
|
||||
|
||||
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
|
||||
// CUTLASS FP8 kernels need at least
|
||||
// CUDA 12.0 on SM90 systems (Hopper)
|
||||
// CUDA 12.4 on SM89 systems (Lovelace)
|
||||
|
||||
#if defined CUDA_VERSION
|
||||
if (cuda_device_capability >= 90) {
|
||||
return CUDA_VERSION >= 12000;
|
||||
} else if (cuda_device_capability >= 89) {
|
||||
return CUDA_VERSION >= 12040;
|
||||
}
|
||||
#endif
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
|
||||
// CUTLASS block-quantized FP8 kernels need at least CUDA 12.0
|
||||
// and at least SM90 (Hopper)
|
||||
|
||||
#if defined CUDA_VERSION
|
||||
if (cuda_device_capability >= 100) {
|
||||
return CUDA_VERSION >= 12080;
|
||||
} else if (cuda_device_capability >= 90) {
|
||||
return CUDA_VERSION >= 12000;
|
||||
}
|
||||
#endif
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool cutlass_group_gemm_supported(int64_t cuda_device_capability) {
|
||||
// CUTLASS grouped FP8 kernels need at least CUDA 12.3 and SM90 (Hopper)
|
||||
// or CUDA 12.8 and SM100 (Blackwell)
|
||||
|
||||
#if defined CUDA_VERSION
|
||||
if (cuda_device_capability >= 100) {
|
||||
return CUDA_VERSION >= 12080;
|
||||
}
|
||||
if (cuda_device_capability >= 90) {
|
||||
return CUDA_VERSION >= 12030;
|
||||
}
|
||||
#endif
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
// Checks for conformality
|
||||
TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
|
||||
TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
|
||||
b.size(1) == c.size(1));
|
||||
|
||||
// Check for strides and alignment
|
||||
TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major
|
||||
TORCH_CHECK(b.stride(0) == 1); // Column-major
|
||||
TORCH_CHECK(c.stride(0) % 16 == 0 &&
|
||||
b.stride(1) % 16 == 0); // 16 Byte Alignment
|
||||
|
||||
if (bias) {
|
||||
TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
|
||||
bias->dim() == 1);
|
||||
}
|
||||
|
||||
at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
|
||||
int32_t version_num = get_sm_version_num();
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
|
||||
if (version_num >= 120) {
|
||||
cutlass_scaled_mm_sm120(c, a, b, a_scales, b_scales, bias);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
|
||||
if (version_num >= 100 && version_num < 120) {
|
||||
cutlass_scaled_mm_sm100(c, a, b, a_scales, b_scales, bias);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
// Guard against compilation issues for sm90 kernels
|
||||
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
|
||||
if (version_num >= 90 && version_num < 100) {
|
||||
// Hopper
|
||||
cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
|
||||
if (version_num == 89) {
|
||||
// Ada Lovelace
|
||||
cutlass_scaled_mm_sm89(c, a, b, a_scales, b_scales, bias);
|
||||
return;
|
||||
}
|
||||
|
||||
if (version_num >= 80) {
|
||||
// Ampere
|
||||
cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias);
|
||||
return;
|
||||
}
|
||||
|
||||
if (version_num >= 75) {
|
||||
// Turing
|
||||
cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false,
|
||||
"No compiled cutlass_scaled_mm for a compute capability less than "
|
||||
"CUDA device capability: ",
|
||||
version_num);
|
||||
}
|
||||
|
||||
void cutlass_moe_mm(
|
||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch) {
|
||||
int32_t version_num = get_sm_version_num();
|
||||
#if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100
|
||||
if (version_num >= 100 && version_num < 110) {
|
||||
cutlass_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
|
||||
expert_offsets, problem_sizes, a_strides, b_strides,
|
||||
c_strides, per_act_token, per_out_ch);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
|
||||
if (version_num >= 90 && version_num < 100) {
|
||||
cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
|
||||
expert_offsets, problem_sizes, a_strides, b_strides,
|
||||
c_strides, per_act_token, per_out_ch);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false,
|
||||
"No compiled cutlass_scaled_mm for CUDA device capability: ", version_num,
|
||||
". Required capability: 90 or 100");
|
||||
}
|
||||
|
||||
void get_cutlass_moe_mm_data(
|
||||
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
|
||||
const int64_t num_experts, const int64_t n, const int64_t k,
|
||||
const std::optional<torch::Tensor>& blockscale_offsets,
|
||||
const bool is_gated) {
|
||||
// This function currently gets compiled only if we have a valid cutlass moe
|
||||
// mm to run it for.
|
||||
int32_t version_num = get_sm_version_num();
|
||||
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
|
||||
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
|
||||
(defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
|
||||
get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1,
|
||||
problem_sizes2, input_permutation,
|
||||
output_permutation, num_experts, n, k,
|
||||
blockscale_offsets, is_gated);
|
||||
return;
|
||||
#endif
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false,
|
||||
"No compiled get_cutlass_moe_mm_data: no cutlass_scaled_mm kernel for "
|
||||
"CUDA device capability: ",
|
||||
version_num, ". Required capability: 90, 100, or 120");
|
||||
}
|
||||
|
||||
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
|
||||
const torch::Tensor& expert_first_token_offset,
|
||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||
const int64_t n, const int64_t k, const bool swap_ab) {
|
||||
int32_t version_num = get_sm_version_num();
|
||||
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
|
||||
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
|
||||
(defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
|
||||
get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
|
||||
expert_first_token_offset, problem_sizes1, problem_sizes2, n, k, swap_ab);
|
||||
return;
|
||||
#endif
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false,
|
||||
"No compiled get_cutlass_moe_mm_problem_sizes_from_expert_offsets: "
|
||||
"no cutlass_scaled_mm kernel for CUDA device capability: ",
|
||||
version_num, ". Required capability: 90, 100, or 120");
|
||||
}
|
||||
|
||||
void get_cutlass_batched_moe_mm_data(torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2,
|
||||
const torch::Tensor& expert_num_tokens,
|
||||
const int64_t num_local_experts,
|
||||
const int64_t padded_m, const int64_t n,
|
||||
const int64_t k) {
|
||||
// This function currently gets compiled only if we have a valid cutlass moe
|
||||
// mm to run it for.
|
||||
int32_t version_num = get_sm_version_num();
|
||||
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
|
||||
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
|
||||
(defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
|
||||
get_cutlass_batched_moe_mm_data_caller(expert_offsets, problem_sizes1,
|
||||
problem_sizes2, expert_num_tokens,
|
||||
num_local_experts, padded_m, n, k);
|
||||
return;
|
||||
#endif
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false,
|
||||
"No compiled get_cutlass_batched_moe_mm_data: no "
|
||||
"cutlass_scaled_mm kernel "
|
||||
"for CUDA device capability: ",
|
||||
version_num,
|
||||
". Required capability: 90, 100, or 120");
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
std::optional<torch::Tensor> const& azp,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
// Checks for conformality
|
||||
TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
|
||||
TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
|
||||
b.size(1) == c.size(1));
|
||||
TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
|
||||
TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));
|
||||
|
||||
// Check for strides and alignment
|
||||
TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major
|
||||
TORCH_CHECK(b.stride(0) == 1); // Column-major
|
||||
TORCH_CHECK(c.stride(0) % 16 == 0 &&
|
||||
b.stride(1) % 16 == 0); // 16 Byte Alignment
|
||||
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
|
||||
// bias, azp, azp_adj are all 1d
|
||||
// bias and azp_adj have n elements, azp has m elements
|
||||
if (bias) {
|
||||
TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous());
|
||||
}
|
||||
if (azp) {
|
||||
TORCH_CHECK(azp->numel() == a.size(0) && azp->is_contiguous());
|
||||
}
|
||||
TORCH_CHECK(azp_adj.numel() == b.size(1) && azp_adj.is_contiguous());
|
||||
|
||||
// azp & bias types
|
||||
TORCH_CHECK(azp_adj.dtype() == torch::kInt32);
|
||||
TORCH_CHECK(!azp || azp->dtype() == torch::kInt32);
|
||||
TORCH_CHECK(!bias || bias->dtype() == c.dtype(),
|
||||
"currently bias dtype must match output dtype ", c.dtype());
|
||||
|
||||
at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
|
||||
|
||||
int32_t version_num = get_sm_version_num();
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
|
||||
if (version_num >= 90) {
|
||||
cutlass_scaled_mm_azp_sm90(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
|
||||
if (version_num == 89) {
|
||||
// Ada Lovelace
|
||||
cutlass_scaled_mm_azp_sm89(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
|
||||
return;
|
||||
}
|
||||
|
||||
if (version_num >= 80) {
|
||||
// Ampere
|
||||
cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
|
||||
return;
|
||||
}
|
||||
|
||||
// Turing
|
||||
TORCH_CHECK(version_num >= 75);
|
||||
cutlass_scaled_mm_azp_sm75(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
|
||||
return;
|
||||
#endif
|
||||
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false,
|
||||
"No compiled cutlass_scaled_mm_azp for a compute capability less than "
|
||||
"CUDA device capability: ",
|
||||
version_num);
|
||||
}
|
||||
@@ -439,6 +439,90 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
" -> ()");
|
||||
// conditionally compiled so impl registration is in source file
|
||||
|
||||
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
|
||||
// quantization, as well as bias
|
||||
ops.def(
|
||||
"cutlass_scaled_mm(Tensor! out, Tensor a,"
|
||||
" Tensor b, Tensor a_scales,"
|
||||
" Tensor b_scales, Tensor? bias) -> ()");
|
||||
ops.impl("cutlass_scaled_mm", torch::kCUDA, &cutlass_scaled_mm);
|
||||
|
||||
// CUTLASS w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
|
||||
// quantization.
|
||||
ops.def(
|
||||
"cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
|
||||
" Tensor b, Tensor a_scales,"
|
||||
" Tensor b_scales, Tensor azp_adj,"
|
||||
" Tensor? azp, Tensor? bias) -> ()");
|
||||
ops.impl("cutlass_scaled_mm_azp", torch::kCUDA, &cutlass_scaled_mm_azp);
|
||||
|
||||
// Check if cutlass scaled_mm is supported for CUDA devices of the given
|
||||
// capability
|
||||
ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool");
|
||||
ops.impl("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
|
||||
|
||||
// Check if cutlass grouped gemm is supported for CUDA devices of the given
|
||||
// capability
|
||||
ops.def("cutlass_group_gemm_supported(int cuda_device_capability) -> bool");
|
||||
ops.impl("cutlass_group_gemm_supported", &cutlass_group_gemm_supported);
|
||||
|
||||
// CUTLASS w8a8 grouped GEMM
|
||||
ops.def(
|
||||
"cutlass_moe_mm(Tensor! out_tensors, Tensor a_tensors, Tensor b_tensors, "
|
||||
" Tensor a_scales, Tensor b_scales, Tensor expert_offsets, "
|
||||
" Tensor problem_sizes, Tensor a_strides, "
|
||||
" Tensor b_strides, Tensor c_strides, bool per_act_token, "
|
||||
" bool per_out_ch) -> ()");
|
||||
ops.impl("cutlass_moe_mm", torch::kCUDA, &cutlass_moe_mm);
|
||||
|
||||
// A function that computes data required to run fused MoE with w8a8 grouped
|
||||
// GEMM. It takes topk_ids as an input, and computes expert_offsets
|
||||
// (token start indices of each expert). In addition to this, it computes
|
||||
// problem sizes for each expert's multiplication used by the two mms called
|
||||
// from fused MoE operation, and arrays with permutations required to shuffle
|
||||
// and de-shuffle the input/output of the fused operation.
|
||||
ops.def(
|
||||
"get_cutlass_moe_mm_data(Tensor topk_ids, Tensor! expert_offsets, "
|
||||
" Tensor! problem_sizes1, Tensor! problem_sizes2, "
|
||||
" Tensor! input_permutation, "
|
||||
" Tensor! output_permutation, int num_experts, "
|
||||
" int n, int k, Tensor? blockscale_offsets, "
|
||||
" bool is_gated) -> ()");
|
||||
ops.impl("get_cutlass_moe_mm_data", torch::kCUDA, &get_cutlass_moe_mm_data);
|
||||
|
||||
// compute per-expert problem sizes from expert_first_token_offset
|
||||
// produced by vLLM's moe_permute kernel
|
||||
ops.def(
|
||||
"get_cutlass_moe_mm_problem_sizes_from_expert_offsets("
|
||||
" Tensor expert_first_token_offset, "
|
||||
" Tensor! problem_sizes1, "
|
||||
" Tensor! problem_sizes2, "
|
||||
" int n, int k, bool swap_ab) -> ()");
|
||||
ops.impl("get_cutlass_moe_mm_problem_sizes_from_expert_offsets", torch::kCUDA,
|
||||
&get_cutlass_moe_mm_problem_sizes_from_expert_offsets);
|
||||
|
||||
// A function that computes data required to run fused MoE with w8a8 grouped
|
||||
// GEMM in batched expert format. It takes expert_num_tokens
|
||||
// as an input, and computes expert_offsets (token start indices of each
|
||||
// expert). In addition to this, it computes problem sizes for each expert's
|
||||
// multiplication used by the two mms called from fused MoE operation.
|
||||
ops.def(
|
||||
"get_cutlass_batched_moe_mm_data(Tensor! expert_offsets, "
|
||||
" Tensor! problem_sizes1, "
|
||||
" Tensor! problem_sizes2, "
|
||||
" Tensor expert_num_tokens, "
|
||||
" int num_local_experts, int padded_m, "
|
||||
" int n, int k) -> ()");
|
||||
ops.impl("get_cutlass_batched_moe_mm_data", torch::kCUDA,
|
||||
&get_cutlass_batched_moe_mm_data);
|
||||
|
||||
// Check if cutlass scaled_mm supports block quantization (used by DeepSeekV3)
|
||||
ops.def(
|
||||
"cutlass_scaled_mm_supports_block_fp8(int cuda_device_capability) -> "
|
||||
"bool");
|
||||
ops.impl("cutlass_scaled_mm_supports_block_fp8",
|
||||
&cutlass_scaled_mm_supports_block_fp8);
|
||||
|
||||
// SM100 CUTLASS MLA decode
|
||||
ops.def(
|
||||
"sm100_cutlass_mla_decode(Tensor! out, Tensor! lse, Tensor q_nope,"
|
||||
@@ -556,7 +640,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"Tensor? cache_indices,"
|
||||
"Tensor? has_initial_state,"
|
||||
"Tensor! ssm_states,"
|
||||
"int null_block_id,"
|
||||
"int pad_slot_id,"
|
||||
"int block_size,"
|
||||
"Tensor? block_idx_first_scheduled_token,"
|
||||
"Tensor? block_idx_last_scheduled_token,"
|
||||
|
||||
@@ -590,10 +590,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
# Install FlashInfer JIT cache (requires CUDA-version-specific index URL)
|
||||
# https://docs.flashinfer.ai/installation.html
|
||||
# From versions.json: .flashinfer.version
|
||||
# 0.6.7: CUTLASS 4.4.2 bump, fixes TMA grouped GEMM on SM12x (flashinfer#2798)
|
||||
# TODO: bump to 0.6.8 when released for NVFP4/MXFP4 group GEMMs on
|
||||
# SM120/SM121 (RTX 50 / DGX Spark) via flashinfer#2738
|
||||
ARG FLASHINFER_VERSION=0.6.7
|
||||
ARG FLASHINFER_VERSION=0.6.6
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system flashinfer-jit-cache==${FLASHINFER_VERSION} \
|
||||
--extra-index-url https://flashinfer.ai/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') \
|
||||
|
||||
@@ -217,16 +217,13 @@ RUN pip install setuptools==75.6.0 packaging==23.2 ninja==1.11.1.3 build==1.2.2.
|
||||
|
||||
|
||||
# build flashinfer for torch nightly from source around 10 mins
|
||||
# release version: v0.6.7
|
||||
# 0.6.7: CUTLASS 4.4.2 bump, fixes TMA grouped GEMM on SM12x (flashinfer#2798)
|
||||
# TODO: bump to 0.6.8 when released for NVFP4/MXFP4 group GEMMs on
|
||||
# SM120/SM121 (RTX 50 / DGX Spark) via flashinfer#2738
|
||||
# release version: v0.6.6
|
||||
# todo(elainewy): cache flashinfer build result for faster build
|
||||
ENV CCACHE_DIR=/root/.cache/ccache
|
||||
RUN --mount=type=cache,target=/root/.cache/ccache \
|
||||
--mount=type=cache,target=/root/.cache/uv \
|
||||
echo "git clone flashinfer..." \
|
||||
&& git clone --depth 1 --branch v0.6.7 --recursive https://github.com/flashinfer-ai/flashinfer.git \
|
||||
&& git clone --depth 1 --branch v0.6.6 --recursive https://github.com/flashinfer-ai/flashinfer.git \
|
||||
&& cd flashinfer \
|
||||
&& git submodule update --init --recursive \
|
||||
&& echo "finish git clone flashinfer..." \
|
||||
|
||||
@@ -111,9 +111,12 @@ CMD ["/bin/bash"]
|
||||
|
||||
FROM vllm-base AS vllm-openai
|
||||
|
||||
# install development dependencies (for testing)
|
||||
# install additional dependencies for openai api server
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install -e tests/vllm_test_utils
|
||||
uv pip install accelerate hf_transfer pytest pytest_asyncio lm_eval[api] modelscope
|
||||
|
||||
# install development dependencies (for testing)
|
||||
RUN uv pip install -e tests/vllm_test_utils
|
||||
|
||||
# install NIXL and UCX from source code
|
||||
ARG UCX_VERSION=e5d98879705239d254ede40b4a52891850cb5349
|
||||
|
||||
@@ -68,7 +68,7 @@
|
||||
"default": "true"
|
||||
},
|
||||
"FLASHINFER_VERSION": {
|
||||
"default": "0.6.7"
|
||||
"default": "0.6.6"
|
||||
},
|
||||
"GDRCOPY_CUDA_VERSION": {
|
||||
"default": "12.8"
|
||||
|
||||
@@ -23,7 +23,7 @@ Now supports 6 types of connectors:
|
||||
|
||||
- **ExampleConnector**: refer to [examples/offline_inference/disaggregated-prefill-v1/run.sh](../../examples/offline_inference/disaggregated-prefill-v1/run.sh) for the example usage of ExampleConnector disaggregated prefilling.
|
||||
- **LMCacheConnectorV1**: refer to [examples/others/lmcache/disagg_prefill_lmcache_v1/disagg_example_nixl.sh](../../examples/others/lmcache/disagg_prefill_lmcache_v1/disagg_example_nixl.sh) for the example usage of LMCacheConnectorV1 disaggregated prefilling which uses NIXL as the underlying KV transmission.
|
||||
- **NixlConnector**: refer to [tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh](../../tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh) for the example usage of NixlConnector disaggregated prefilling which support fully async send/recv. For detailed usage guide, see [NixlConnector Usage Guide](nixl_connector_usage.md). For feature compatibility details, see [NixlConnector Compatibility Matrix](nixl_connector_compatibility.md).
|
||||
- **NixlConnector**: refer to [tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh](../../tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh) for the example usage of NixlConnector disaggregated prefilling which support fully async send/recv. For detailed usage guide, see [NixlConnector Usage Guide](nixl_connector_usage.md).
|
||||
- **P2pNcclConnector**: refer to [examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_example_p2p_nccl_xpyd.sh](../../examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_example_p2p_nccl_xpyd.sh) for the example usage of P2pNcclConnector disaggregated prefilling.
|
||||
- **MooncakeConnector**: refer to [examples/online_serving/disaggregated_serving/mooncake_connector/run_mooncake_connector.sh](../../examples/online_serving/disaggregated_serving/mooncake_connector/run_mooncake_connector.sh) for the example usage of ExampleConnector disaggregated prefilling. For detailed usage guide, see [MooncakeConnector Usage Guide](mooncake_connector_usage.md).
|
||||
- **MultiConnector**: take advantage of the kv_connector_extra_config: dict[str, Any] already present in KVTransferConfig to stash all the connectors we want in an ordered list of kwargs.such as:
|
||||
|
||||
@@ -1,104 +0,0 @@
|
||||
# NixlConnector Compatibility Matrix
|
||||
|
||||
This page documents the feature compatibility of **disaggregated prefilling with the NixlConnector**. For general usage instructions, see the [NixlConnector Usage Guide](nixl_connector_usage.md). For an overview of disaggregated prefilling, see [Disaggregated Prefilling](disagg_prefill.md).
|
||||
|
||||
!!! note
|
||||
This page reflects the current state of the codebase and is subject to change as features evolve. Entries marked 🟠 or ❌ may link to tracking issues. See the [NIXL connector roadmap](https://github.com/vllm-project/vllm/issues/33702) for upcoming feature development.
|
||||
|
||||
**Legend:**
|
||||
|
||||
- ✅ = Fully supported
|
||||
- 🟠 = Partial support (see footnotes)
|
||||
- ❌ = Not supported
|
||||
- ❔ = Unknown / not yet validated
|
||||
- 🚧 = Work in progress
|
||||
|
||||
!!! info "Universally supported features"
|
||||
The following features work with **all** model architectures when using NixlConnector PD disaggregated serving:
|
||||
|
||||
[Chunked Prefill](../configuration/optimization.md#chunked-prefill) |
|
||||
[APC (Prefix Caching)](automatic_prefix_caching.md) |
|
||||
[Data Parallel](../serving/data_parallel_deployment.md) |
|
||||
CUDA graph |
|
||||
Logprobs |
|
||||
Prompt Logprobs |
|
||||
[Prompt Embeds](prompt_embeds.md) |
|
||||
Multiple NIXL backends (UCX, GDS, LIBFABRIC, etc.)
|
||||
|
||||
## Model Architecture x Capability
|
||||
|
||||
<style>
|
||||
td:not(:first-child) {
|
||||
text-align: center !important;
|
||||
}
|
||||
td {
|
||||
padding: 0.5rem !important;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
th {
|
||||
padding: 0.5rem !important;
|
||||
min-width: 0 !important;
|
||||
}
|
||||
|
||||
th:not(:first-child) {
|
||||
writing-mode: vertical-lr;
|
||||
transform: rotate(180deg)
|
||||
}
|
||||
</style>
|
||||
|
||||
| Model type | <abbr title="Basic Prefill/Decode disaggregation">Basic PD</abbr> | <abbr title="Speculative Decoding">Spec Decode</abbr> | <abbr title="Heterogeneous Tensor Parallelism (P TP != D TP)">Hetero TP</abbr> | <abbr title="Cross-layer blocks optimization">Cross-layer blocks</abbr> | <abbr title="Sliding Window Attention">SWA</abbr> | <abbr title="CPU host buffer offload (e.g. TPU)">Host buffer</abbr> | <abbr title="Different block sizes on P and D">Hetero block size</abbr> |
|
||||
| - | - | - | - | - | - | - | - |
|
||||
| Dense Transformers | ✅ | ✅<sup>1</sup> | ✅ | ✅<sup>2</sup> | ✅ | ✅ | 🟠<sup>3</sup> |
|
||||
| MLA (e.g. DeepSeek-V2/V3) | ✅ | ✅<sup>1</sup> | 🟠<sup>4</sup> | ✅<sup>2</sup> | ✅ | ✅ | 🟠<sup>3</sup> |
|
||||
| Sparse MLA (e.g. DeepSeek-V3.2) | ✅ | ✅<sup>1</sup> | 🟠<sup>4</sup> | ✅<sup>2</sup> | ✅ | ✅ | 🟠<sup>3</sup> |
|
||||
| Hybrid SSM / Mamba | ✅ | ❔ | 🚧<sup>5</sup> | ❌ | ✅ | ✅ | ❌<sup>6</sup> |
|
||||
| MoE | ✅ | ✅<sup>1</sup> | ✅ | ✅<sup>2</sup> | ✅ | ✅ | 🟠<sup>3</sup> |
|
||||
| Multimodal | ❔ | ❔ | ❔ | ❔ | ❔ | ❔ | ❔ |
|
||||
| Encoder-Decoder | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
|
||||
<sup>1</sup> P and D instances must use the same speculation configuration.
|
||||
|
||||
<sup>2</sup> Requires `FLASH_ATTN` or `FLASHINFER` backend **and** `HND` KV cache layout. Enable via `--kv-transfer-config '{"kv_connector_extra_config": {"enable_cross_layers_blocks": "True"}}'`.
|
||||
|
||||
<sup>3</sup> Supported only when HMA is **not** required (i.e., non-hybrid models). Block IDs are remapped automatically. Only P block size < D block size is supported.
|
||||
|
||||
<sup>4</sup> MLA KV cache is replicated across TP workers, so heterogeneous TP works but there is no head-splitting. When P TP > D TP, only a single read is executed (redundant ranks are skipped). D TP > P TP also works.
|
||||
|
||||
<sup>5</sup> Hybrid SSM (Mamba) models require **homogeneous TP** (`P TP == D TP`). Heterogeneous TP is not yet supported for Mamba layers.
|
||||
|
||||
<sup>6</sup> HMA (required by hybrid models) does not support different remote block sizes.
|
||||
|
||||
## Configuration Notes
|
||||
|
||||
### What must match between P and D
|
||||
|
||||
By default, a **compatibility hash** is checked during handshake. P and D instances must agree on:
|
||||
|
||||
- vLLM version and NIXL connector version
|
||||
- Model (architecture, dtype, number of KV heads, head size, number of hidden layers)
|
||||
- Attention backend
|
||||
- KV cache dtype (`cache_dtype`)
|
||||
|
||||
!!! warning
|
||||
Disable the hash check with `--kv-transfer-config '{"kv_connector_extra_config": {"enforce_handshake_compat": false}}'` at your own risk.
|
||||
|
||||
### What can safely differ between P and D
|
||||
|
||||
- `tensor-parallel-size` (heterogeneous TP, subject to model restrictions above)
|
||||
- `block-size` (heterogeneous block size, subject to restrictions above)
|
||||
- Number of KV cache blocks (determined by available memory on each instance)
|
||||
|
||||
### KV cache layout
|
||||
|
||||
- NixlConnector defaults to **`HND`** layout for optimal transfer performance (non-MLA models).
|
||||
- `NHD` layout is supported but does **not** allow heterogeneous TP head splitting.
|
||||
- Experimental `HND` ↔ `NHD` permute: enable via `--kv-transfer-config '{"enable_permute_local_kv": true}'`. Not supported with HMA.
|
||||
|
||||
### Quantized KV cache
|
||||
|
||||
[Quantized KV cache](quantization/quantized_kvcache.md) (e.g., FP8) requires both P and D instances to use the **same** `cache_dtype`. Mismatched cache dtypes will fail the compatibility hash check during handshake.
|
||||
|
||||
- **Static quantization** (scales loaded from checkpoint): ✅ Supported. Scales are loaded independently by each instance from the model checkpoint.
|
||||
- **Dynamic quantization** (scales computed at runtime): ❌ Not supported. Per-block scales are not transferred alongside KV cache data.
|
||||
- **Packed-layout scales** (scales stored inline with weights): ✅ Supported. Scales are transferred together with the KV cache blocks.
|
||||
@@ -2,8 +2,6 @@
|
||||
|
||||
NixlConnector is a high-performance KV cache transfer connector for vLLM's disaggregated prefilling feature. It provides fully asynchronous send/receive operations using the NIXL library for efficient cross-process KV cache transfer.
|
||||
|
||||
For feature compatibility details (supported model architectures, TP configurations, and feature interactions), see the [NixlConnector Compatibility Matrix](nixl_connector_compatibility.md).
|
||||
|
||||
## Prerequisites
|
||||
|
||||
### Installation
|
||||
|
||||
@@ -244,12 +244,12 @@ response = client.chat.completions.create(
|
||||
|
||||
Some models, such as [Qwen3](https://qwen.readthedocs.io/en/latest/getting_started/quickstart.html#thinking-budget), [DeepSeek](https://www.alibabacloud.com/help/en/model-studio/deep-thinking), and [Nemotron3](https://huggingface.co/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16), support a thinking budget that limits the maximum number of tokens used for reasoning.
|
||||
|
||||
Token counting starts from `think_start_str`. Once the reasoning token count reaches the configured `thinking_token_budget`, vLLM forces the model to produce `think_end_str`, effectively terminating the reasoning block.
|
||||
Token counting starts from `reasoning_start_str`. Once the reasoning token count reaches the configured `thinking_token_budget`, vLLM forces the model to produce `reasoning_end_str`, effectively terminating the reasoning block.
|
||||
|
||||
To use this feature:
|
||||
|
||||
- `--reasoning-parser` enables reasoning extraction.
|
||||
- `--reasoning-config` defines the reasoning boundary tokens (e.g., `think_start_str`, `think_end_str`).
|
||||
- `--reasoning-config` defines the reasoning boundary tokens (e.g., `reasoning_start_str`, `reasoning_end_str`).
|
||||
- `thinking_token_budget` (a sampling parameter) sets the per-request reasoning token limit.
|
||||
|
||||
If `thinking_token_budget` is not specified, no explicit reasoning limit is applied beyond normal generation constraints such as `max_tokens`.
|
||||
@@ -257,20 +257,20 @@ If `thinking_token_budget` is not specified, no explicit reasoning limit is appl
|
||||
`--reasoning-config` accepts a JSON object corresponding to
|
||||
[ReasoningConfig][vllm.config.ReasoningConfig] with the following fields:
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------------------|----------------|--------------------------------------------------|
|
||||
| `think_start_str` | `str \| null` | String that marks the start of reasoning content |
|
||||
| `think_end_str` | `str \| null` | String that marks the end of reasoning content |
|
||||
| Field | Type | Description |
|
||||
|-----------------------|----------------|--------------------------------------------------|
|
||||
| `reasoning_start_str` | `str \| null` | String that marks the start of reasoning content |
|
||||
| `reasoning_end_str` | `str \| null` | String that marks the end of reasoning content |
|
||||
|
||||
!!! note
|
||||
`think_end_str` can include a transition phrase before the think end token. For example, setting `think_end_str` to `"I have to give the solution based on the thinking directly now.</think>"` instructs the model to emit that phrase when the budget is exhausted, making the reasoning termination more natural.
|
||||
`reasoning_end_str` can include a transition phrase before the reasoning end token. For example, setting `reasoning_end_str` to `"I have to give the solution based on the reasoning directly now.</think>"` instructs the model to emit that phrase when the budget is exhausted, making the reasoning termination more natural.
|
||||
|
||||
### Online Serving
|
||||
|
||||
```bash
|
||||
vllm serve Qwen/Qwen3-0.6B \
|
||||
--reasoning-parser qwen3 \
|
||||
--reasoning-config '{"think_start_str": "<think>", "think_end_str": "I have to give the solution based on the thinking directly now.</think>"}'
|
||||
--reasoning-config '{"reasoning_start_str": "<think>", "reasoning_end_str": "I have to give the solution based on the reasoning directly now.</think>"}'
|
||||
```
|
||||
|
||||
Then make a request with `thinking_token_budget` to limit the reasoning tokens:
|
||||
@@ -298,8 +298,8 @@ from vllm.config import ReasoningConfig
|
||||
llm = LLM(
|
||||
model="Qwen/Qwen3-0.6B",
|
||||
reasoning_config=ReasoningConfig(
|
||||
think_start_str="<think>",
|
||||
think_end_str="I have to give the solution based on the thinking directly now.</think>",
|
||||
reasoning_start_str="<think>",
|
||||
reasoning_end_str="I have to give the solution based on the thinking directly now.</think>",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -505,7 +505,7 @@ Here is a summary of a plugin file:
|
||||
|
||||
# adjust request. e.g.: set skip special tokens
|
||||
# to False for tool call output.
|
||||
def adjust_request(self, request: ChatCompletionRequest | ResponsesRequest) -> ChatCompletionRequest | ResponsesRequest:
|
||||
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
|
||||
return request
|
||||
|
||||
# implement the tool call parse for stream call
|
||||
|
||||
@@ -1,14 +1,7 @@
|
||||
<!-- markdownlint-disable MD041 MD051 -->
|
||||
--8<-- [start:installation]
|
||||
|
||||
vLLM supports AMD GPUs with ROCm 6.3 or above. Pre-built wheels are available for ROCm 7.0 and ROCm 7.2.1.
|
||||
|
||||
#### Prebuilt Wheels
|
||||
|
||||
| ROCm Variant | Python Version | ROCm Version | glibc Requirement | Supported Versions |
|
||||
| ------------ | -------------- | ------------ | ----------------- | ------------------ |
|
||||
| `rocm700` | 3.12 | 7.0 | >= 2.35 | `0.14.0` to `0.18.0` |
|
||||
| `rocm721` | 3.12 | 7.2.1 | >= 2.35 | Nightly releases after commit `171775f306a333a9cf105bfd533bf3e113d401d9` |
|
||||
vLLM supports AMD GPUs with ROCm 6.3 or above. Pre-built wheels are available for ROCm 7.0.
|
||||
|
||||
--8<-- [end:installation]
|
||||
--8<-- [start:requirements]
|
||||
@@ -30,112 +23,26 @@ If you need a different ROCm version or want to use an existing PyTorch installa
|
||||
To install the latest version of vLLM for Python 3.12, ROCm 7.0 and `glibc >= 2.35`.
|
||||
|
||||
```bash
|
||||
uv pip install vllm --extra-index-url https://wheels.vllm.ai/rocm/ --upgrade
|
||||
uv pip install vllm --extra-index-url https://wheels.vllm.ai/rocm/
|
||||
```
|
||||
|
||||
!!! tip
|
||||
You can find out about which ROCm version the latest vLLM supports by checking the `vllm` package in index in extra-index-url <https://wheels.vllm.ai/rocm/> at [https://wheels.vllm.ai/rocm/vllm](https://wheels.vllm.ai/rocm/vllm) .
|
||||
|
||||
Another approach is that you can use this following commands to automatically extract the wheel variants:
|
||||
|
||||
```bash
|
||||
# automatically extract the available rocm variant
|
||||
export VLLM_ROCM_VARIANT=$(curl -s https://wheels.vllm.ai/rocm/vllm | grep -oP 'rocm\d+' | head -1)
|
||||
|
||||
# automatically extract the vLLM version
|
||||
export VLLM_VERSION=$(curl -s https://wheels.vllm.ai/rocm/vllm | grep -oP 'vllm-\K[0-9.]+' | head -1)
|
||||
|
||||
# inspect if the ROCm version is compatible with your environment
|
||||
echo $VLLM_ROCM_VARIANT
|
||||
echo $VLLM_VERSION
|
||||
```
|
||||
You can find out about which ROCm version the latest vLLM supports by checking the index in extra-index-url [https://wheels.vllm.ai/rocm/](https://wheels.vllm.ai/rocm/) .
|
||||
|
||||
To install a specific version and ROCm variant of vLLM wheel.
|
||||
|
||||
```bash
|
||||
# version without the `v`
|
||||
uv pip install vllm==${VLLM_VERSION} --extra-index-url https://wheels.vllm.ai/rocm/${VLLM_VERSION}/${VLLM_ROCM_VARIANT}
|
||||
|
||||
# Example
|
||||
uv pip install vllm==0.18.0 --extra-index-url https://wheels.vllm.ai/rocm/0.18.0/rocm700
|
||||
uv pip install vllm --extra-index-url https://wheels.vllm.ai/rocm/0.15.0/rocm700
|
||||
```
|
||||
|
||||
!!! warning "Caveats for using `pip`"
|
||||
|
||||
We recommend leveraging `uv` to install the vLLM wheel. Using `pip` to install from custom indices is cumbersome because `pip` combines packages from `--extra-index-url` and the default index, choosing only the latest version. This makes it difficult to install a wheel from a custom index unless exact versions of all packages are specified. In contrast, `uv` gives the extra index [higher priority than the default index](https://docs.astral.sh/uv/pip/compatibility/#packages-that-exist-on-multiple-indexes).
|
||||
We recommend leveraging `uv` to install vLLM wheel. Using `pip` to install from custom indices is cumbersome, because `pip` combines packages from `--extra-index-url` and the default index, choosing only the latest version, which makes it difficult to install wheel from custom index if exact versions of all packages are specified exactly. In contrast, `uv` gives the extra index [higher priority than the default index](https://docs.astral.sh/uv/pip/compatibility/#packages-that-exist-on-multiple-indexes).
|
||||
|
||||
If you insist on using `pip`, you need to specify the exact vLLM version in the package name and provide the custom index URL `https://wheels.vllm.ai/rocm/${VLLM_VERSION}/${VLLM_ROCM_VARIANT}` via `--extra-index-url`.
|
||||
If you insist on using `pip`, you have to specify the exact vLLM version and full URL of the wheel path `https://wheels.vllm.ai/rocm/<version>/<rocm-variant>` (which can be obtained from the web page).
|
||||
|
||||
```bash
|
||||
pip install vllm==0.18.0+rocm700 --extra-index-url https://wheels.vllm.ai/rocm/0.18.0/rocm700
|
||||
```
|
||||
|
||||
#### Install the latest code
|
||||
|
||||
LLM inference is a fast-evolving field, and the latest code may contain bug fixes, performance improvements, and new features that are not released yet. To allow users to try the latest code without waiting for the next release, vLLM provides wheels for every commit since commit `171775f306a333a9cf105bfd533bf3e113d401d9` on <https://wheels.vllm.ai/rocm/nightly/>. The custom index to be used is `https://wheels.vllm.ai/rocm/nightly/${VLLM_ROCM_VARIANT}`
|
||||
|
||||
**NOTE:** The first ROCm Variant that supports nightly wheel is ROCm 7.2.1
|
||||
|
||||
To install from latest nightly index, run:
|
||||
|
||||
```bash
|
||||
# automatically extract the available rocm variant
|
||||
export VLLM_ROCM_VARIANT=$(curl -s https://wheels.vllm.ai/rocm/nightly | \
|
||||
grep -oP 'rocm\d+' | head -1 | sed 's/%2B/+/g')
|
||||
|
||||
# inspect if the ROCm version is compatible with your environment
|
||||
echo $VLLM_ROCM_VARIANT
|
||||
|
||||
uv pip install --pre vllm \
|
||||
--extra-index-url https://wheels.vllm.ai/rocm/nightly/${VLLM_ROCM_VARIANT} \
|
||||
--index-strategy unsafe-best-match
|
||||
```
|
||||
|
||||
##### Install specific revisions
|
||||
|
||||
If you want to access the wheels for previous commits (e.g. to bisect the behavior change, performance regression), you can specify the commit hash in the URL, example:
|
||||
|
||||
```bash
|
||||
export VLLM_COMMIT=5b8c30d62b754b575e043ce2fc0dcbf8a64f6306
|
||||
|
||||
export VLLM_ROCM_VARIANT=$(curl -s https://wheels.vllm.ai/rocm/${VLLM_COMMIT} | \
|
||||
grep -oP 'rocm\d+' | head -1 | sed 's/%2B/+/g')
|
||||
|
||||
# Extract the version from the wheel URL
|
||||
export VLLM_VERSION=$(curl -s https://wheels.vllm.ai/rocm/${VLLM_COMMIT}/${VLLM_ROCM_VARIANT}/vllm/ | \
|
||||
grep -oP 'vllm-\K[^-]+' | head -1 | sed 's/%2B/+/g')
|
||||
|
||||
# inspect the version if it is compatible with the ROCm version of your environment
|
||||
echo $VLLM_ROCM_VARIANT
|
||||
echo $VLLM_VERSION
|
||||
|
||||
uv pip install vllm==${VLLM_VERSION} \
|
||||
--extra-index-url https://wheels.vllm.ai/rocm/${VLLM_COMMIT}/${VLLM_ROCM_VARIANT} \
|
||||
--index-strategy unsafe-best-match
|
||||
```
|
||||
|
||||
!!! warning "`pip` caveat"
|
||||
|
||||
Using `pip` to install from nightly indices is _not supported_, because `pip` combines packages from `--extra-index-url` and the default index, choosing only the latest version, which makes it difficult to install a development version prior to the released version. In contrast, `uv` gives the extra index [higher priority than the default index](https://docs.astral.sh/uv/pip/compatibility/#packages-that-exist-on-multiple-indexes).
|
||||
|
||||
If you insist on using `pip`, you need to specify the exact vLLM version in the package name and provide the custom index URL (which can be obtained from the web page).
|
||||
|
||||
```bash
|
||||
export VLLM_COMMIT=5b8c30d62b754b575e043ce2fc0dcbf8a64f6306
|
||||
|
||||
export VLLM_ROCM_VARIANT=$(curl -s https://wheels.vllm.ai/rocm/${VLLM_COMMIT} | \
|
||||
grep -oP 'rocm\d+' | head -1 | sed 's/%2B/+/g')
|
||||
|
||||
# Extract the version from the wheel URL
|
||||
export VLLM_VERSION=$(curl -s https://wheels.vllm.ai/rocm/${VLLM_COMMIT}/${VLLM_ROCM_VARIANT}/vllm/ | \
|
||||
grep -oP 'vllm-\K[^-]+' | head -1 | sed 's/%2B/+/g')
|
||||
|
||||
# inspect the version if it is compatible with the ROCm version of your environment
|
||||
echo $VLLM_ROCM_VARIANT
|
||||
echo $VLLM_VERSION
|
||||
|
||||
pip install vllm==${VLLM_VERSION} \
|
||||
--extra-index-url https://wheels.vllm.ai/rocm/${VLLM_COMMIT}/${VLLM_ROCM_VARIANT}
|
||||
pip install vllm==0.15.0+rocm700 --extra-index-url https://wheels.vllm.ai/rocm/0.15.0/rocm700
|
||||
```
|
||||
|
||||
--8<-- [end:pre-built-wheels]
|
||||
@@ -286,24 +193,6 @@ docker run --rm \
|
||||
--model Qwen/Qwen3-0.6B
|
||||
```
|
||||
|
||||
To use the docker image as base for development, you can launch it in interactive session through overriding the entrypoint.
|
||||
|
||||
???+ console "Commands"
|
||||
```bash
|
||||
docker run --rm -it \
|
||||
--group-add=video \
|
||||
--cap-add=SYS_PTRACE \
|
||||
--security-opt seccomp=unconfined \
|
||||
--device /dev/kfd \
|
||||
--device /dev/dri \
|
||||
-v ~/.cache/huggingface:/root/.cache/huggingface \
|
||||
--env "HF_TOKEN=$HF_TOKEN" \
|
||||
--network=host \
|
||||
--ipc=host \
|
||||
--entrypoint /bin/bash \
|
||||
vllm/vllm-openai-rocm:<tag>
|
||||
```
|
||||
|
||||
#### Use AMD's Docker Images (Deprecated)
|
||||
|
||||
!!! warning "Deprecated"
|
||||
|
||||
@@ -66,10 +66,6 @@ Restrict domains that vLLM can access for media URLs by setting
|
||||
`--allowed-media-domains` to prevent Server-Side Request Forgery (SSRF) attacks.
|
||||
(e.g. `--allowed-media-domains upload.wikimedia.org github.com www.bogotobogo.com`)
|
||||
|
||||
This protection applies to both the online serving API (multimodal inputs) and
|
||||
the **batch runner** (`vllm run-batch`), where `file_url` values in batch
|
||||
transcription/translation requests are validated against the same allowlist.
|
||||
|
||||
Without domain restrictions, a malicious user could supply URLs that:
|
||||
|
||||
- **Target internal services**: Access internal network endpoints, cloud metadata
|
||||
|
||||
@@ -4,10 +4,9 @@
|
||||
experimental support for tensor-parallel inference with torchrun,
|
||||
see https://github.com/vllm-project/vllm/issues/11400 for
|
||||
the motivation and use case for this example.
|
||||
run the script with `torchrun --nproc-per-node=4 torchrun_example.py`,
|
||||
the argument `4` should match the product of `tensor_parallel_size` and
|
||||
`pipeline_parallel_size` below. see `tests/distributed/test_torchrun_example.py`
|
||||
for the unit test.
|
||||
run the script with `torchrun --nproc-per-node=2 torchrun_example.py`,
|
||||
the argument 2 should match the `tensor_parallel_size` below.
|
||||
see `tests/distributed/test_torchrun_example.py` for the unit test.
|
||||
"""
|
||||
|
||||
import torch.distributed as dist
|
||||
|
||||
@@ -26,13 +26,8 @@ MODEL="Qwen/Qwen2.5-VL-3B-Instruct" bash disagg_1e1p1d_example.sh
|
||||
|
||||
# Use specific storage path
|
||||
EC_SHARED_STORAGE_PATH="/tmp/my_ec_cache" bash disagg_1e1p1d_example.sh
|
||||
|
||||
# Run on XPU; scripts switch from CUDA_VISIBLE_DEVICES to ZE_AFFINITY_MASK
|
||||
DEVICE_PLATFORM=xpu GPU_E=0 GPU_PD=1 bash disagg_1e1pd_example.sh
|
||||
```
|
||||
|
||||
`DEVICE_PLATFORM` defaults to `cuda`. Set `DEVICE_PLATFORM=xpu` when running these examples on Intel GPUs so the scripts use `ZE_AFFINITY_MASK` instead of `CUDA_VISIBLE_DEVICES` for device selection.
|
||||
|
||||
## Encoder Instances
|
||||
|
||||
Encoder engines should be launched with the following flags:
|
||||
|
||||
@@ -19,29 +19,11 @@ GPU_E="${GPU_E:-2}"
|
||||
GPU_P="${GPU_P:-2}"
|
||||
GPU_D="${GPU_D:-3}"
|
||||
|
||||
# Device platform and affinity env name.
|
||||
# DEVICE_PLATFORM supports: cuda, xpu
|
||||
DEVICE_PLATFORM="${DEVICE_PLATFORM:-cuda}"
|
||||
if [[ -z "${DEVICE_AFFINITY_ENV:-}" ]]; then
|
||||
if [[ "${DEVICE_PLATFORM,,}" == "xpu" ]]; then
|
||||
DEVICE_AFFINITY_ENV="ZE_AFFINITY_MASK"
|
||||
else
|
||||
DEVICE_AFFINITY_ENV="CUDA_VISIBLE_DEVICES"
|
||||
fi
|
||||
fi
|
||||
|
||||
EC_SHARED_STORAGE_PATH="${EC_SHARED_STORAGE_PATH:-/tmp/ec_cache}"
|
||||
TIMEOUT_SECONDS="${TIMEOUT_SECONDS:-12000}" # wait_for_server timeout
|
||||
|
||||
NUM_PROMPTS="${NUM_PROMPTS:-100}" # number of prompts to send in benchmark
|
||||
|
||||
# Serve args
|
||||
GPU_MEMORY_UTILIZATION_E="${GPU_MEMORY_UTILIZATION_E:-0.01}"
|
||||
GPU_MEMORY_UTILIZATION_P="${GPU_MEMORY_UTILIZATION_P:-0.7}"
|
||||
GPU_MEMORY_UTILIZATION_D="${GPU_MEMORY_UTILIZATION_D:-0.7}"
|
||||
MAX_NUM_SEQS="${MAX_NUM_SEQS:-128}"
|
||||
MAX_MODEL_LEN="${MAX_MODEL_LEN:-32768}"
|
||||
|
||||
export UCX_TLS=all
|
||||
export UCX_NET_DEVICES=all
|
||||
|
||||
@@ -110,14 +92,14 @@ mkdir -p "$EC_SHARED_STORAGE_PATH"
|
||||
###############################################################################
|
||||
# Encoder worker
|
||||
###############################################################################
|
||||
env "$DEVICE_AFFINITY_ENV=$GPU_E" vllm serve "$MODEL" \
|
||||
--gpu-memory-utilization "$GPU_MEMORY_UTILIZATION_E" \
|
||||
CUDA_VISIBLE_DEVICES="$GPU_E" vllm serve "$MODEL" \
|
||||
--gpu-memory-utilization 0.01 \
|
||||
--port "$ENCODE_PORT" \
|
||||
--enforce-eager \
|
||||
--enable-request-id-headers \
|
||||
--no-enable-prefix-caching \
|
||||
--max-num-batched-tokens 114688 \
|
||||
--max-num-seqs "$MAX_NUM_SEQS" \
|
||||
--max-num-seqs 128 \
|
||||
--allowed-local-media-path "${GIT_ROOT}"/tests/v1/ec_connector/integration \
|
||||
--ec-transfer-config '{
|
||||
"ec_connector": "ECExampleConnector",
|
||||
@@ -133,16 +115,15 @@ PIDS+=($!)
|
||||
###############################################################################
|
||||
# Prefill worker
|
||||
###############################################################################
|
||||
env "$DEVICE_AFFINITY_ENV=$GPU_P" \
|
||||
CUDA_VISIBLE_DEVICES="$GPU_P" \
|
||||
UCX_NET_DEVICES=all \
|
||||
VLLM_NIXL_SIDE_CHANNEL_PORT=5559 \
|
||||
vllm serve "$MODEL" \
|
||||
--gpu-memory-utilization "$GPU_MEMORY_UTILIZATION_P" \
|
||||
--gpu-memory-utilization 0.7 \
|
||||
--port "$PREFILL_PORT" \
|
||||
--enforce-eager \
|
||||
--enable-request-id-headers \
|
||||
--max-num-seqs "$MAX_NUM_SEQS" \
|
||||
--max-model-len "$MAX_MODEL_LEN" \
|
||||
--max-num-seqs 128 \
|
||||
--allowed-local-media-path "${GIT_ROOT}"/tests/v1/ec_connector/integration \
|
||||
--ec-transfer-config '{
|
||||
"ec_connector": "ECExampleConnector",
|
||||
@@ -162,16 +143,15 @@ PIDS+=($!)
|
||||
###############################################################################
|
||||
# Decode worker
|
||||
###############################################################################
|
||||
env "$DEVICE_AFFINITY_ENV=$GPU_D" \
|
||||
CUDA_VISIBLE_DEVICES="$GPU_D" \
|
||||
UCX_NET_DEVICES=all \
|
||||
VLLM_NIXL_SIDE_CHANNEL_PORT=6000 \
|
||||
vllm serve "$MODEL" \
|
||||
--gpu-memory-utilization "$GPU_MEMORY_UTILIZATION_D" \
|
||||
--gpu-memory-utilization 0.7 \
|
||||
--port "$DECODE_PORT" \
|
||||
--enforce-eager \
|
||||
--enable-request-id-headers \
|
||||
--max-num-seqs "$MAX_NUM_SEQS" \
|
||||
--max-model-len "$MAX_MODEL_LEN" \
|
||||
--max-num-seqs 128 \
|
||||
--allowed-local-media-path "${GIT_ROOT}"/tests/v1/ec_connector/integration \
|
||||
--kv-transfer-config '{
|
||||
"kv_connector": "NixlConnector",
|
||||
|
||||
@@ -17,28 +17,11 @@ PROXY_PORT="${PROXY_PORT:-10001}"
|
||||
GPU_E="${GPU_E:-0}"
|
||||
GPU_PD="${GPU_PD:-1}"
|
||||
|
||||
# Device platform and affinity env name.
|
||||
# DEVICE_PLATFORM supports: cuda, xpu
|
||||
DEVICE_PLATFORM="${DEVICE_PLATFORM:-cuda}"
|
||||
if [[ -z "${DEVICE_AFFINITY_ENV:-}" ]]; then
|
||||
if [[ "${DEVICE_PLATFORM,,}" == "xpu" ]]; then
|
||||
DEVICE_AFFINITY_ENV="ZE_AFFINITY_MASK"
|
||||
else
|
||||
DEVICE_AFFINITY_ENV="CUDA_VISIBLE_DEVICES"
|
||||
fi
|
||||
fi
|
||||
|
||||
EC_SHARED_STORAGE_PATH="${EC_SHARED_STORAGE_PATH:-/tmp/ec_cache}"
|
||||
TIMEOUT_SECONDS="${TIMEOUT_SECONDS:-12000}" # wait_for_server timeout
|
||||
|
||||
NUM_PROMPTS="${NUM_PROMPTS:-100}" # number of prompts to send in benchmark
|
||||
|
||||
# Serve args
|
||||
GPU_MEMORY_UTILIZATION_E="${GPU_MEMORY_UTILIZATION_E:-0.01}"
|
||||
GPU_MEMORY_UTILIZATION_PD="${GPU_MEMORY_UTILIZATION_PD:-0.7}"
|
||||
MAX_NUM_SEQS="${MAX_NUM_SEQS:-128}"
|
||||
MAX_MODEL_LEN="${MAX_MODEL_LEN:-32768}"
|
||||
|
||||
###############################################################################
|
||||
# Helpers
|
||||
###############################################################################
|
||||
@@ -103,14 +86,14 @@ mkdir -p "$EC_SHARED_STORAGE_PATH"
|
||||
###############################################################################
|
||||
# Encoder worker
|
||||
###############################################################################
|
||||
env "$DEVICE_AFFINITY_ENV=$GPU_E" vllm serve "$MODEL" \
|
||||
--gpu-memory-utilization "$GPU_MEMORY_UTILIZATION_E" \
|
||||
CUDA_VISIBLE_DEVICES="$GPU_E" vllm serve "$MODEL" \
|
||||
--gpu-memory-utilization 0.01 \
|
||||
--port "$ENCODE_PORT" \
|
||||
--enforce-eager \
|
||||
--enable-request-id-headers \
|
||||
--no-enable-prefix-caching \
|
||||
--max-num-batched-tokens 114688 \
|
||||
--max-num-seqs "$MAX_NUM_SEQS" \
|
||||
--max-num-seqs 128 \
|
||||
--allowed-local-media-path "${GIT_ROOT}"/tests/v1/ec_connector/integration \
|
||||
--ec-transfer-config '{
|
||||
"ec_connector": "ECExampleConnector",
|
||||
@@ -126,13 +109,12 @@ PIDS+=($!)
|
||||
###############################################################################
|
||||
# Prefill+Decode worker
|
||||
###############################################################################
|
||||
env "$DEVICE_AFFINITY_ENV=$GPU_PD" vllm serve "$MODEL" \
|
||||
--gpu-memory-utilization "$GPU_MEMORY_UTILIZATION_PD" \
|
||||
CUDA_VISIBLE_DEVICES="$GPU_PD" vllm serve "$MODEL" \
|
||||
--gpu-memory-utilization 0.7 \
|
||||
--port "$PREFILL_DECODE_PORT" \
|
||||
--enforce-eager \
|
||||
--enable-request-id-headers \
|
||||
--max-num-seqs "$MAX_NUM_SEQS" \
|
||||
--max-model-len "$MAX_MODEL_LEN" \
|
||||
--max-num-seqs 128 \
|
||||
--allowed-local-media-path "${GIT_ROOT}"/tests/v1/ec_connector/integration \
|
||||
--ec-transfer-config '{
|
||||
"ec_connector": "ECExampleConnector",
|
||||
|
||||
96
managed_alloc.cu
Normal file
96
managed_alloc.cu
Normal file
@@ -0,0 +1,96 @@
|
||||
// managed_alloc.cu - cudaMallocManaged allocator for PyTorch
|
||||
// Compile: nvcc -shared -o libmanaged_alloc.so managed_alloc.cu -Xcompiler -fPIC
|
||||
// Compatible with CUDA 13+ (uses cudaMemLocation API)
|
||||
//
|
||||
// Key design decisions for GH200 EGM:
|
||||
// 1. cudaMallocManaged → allocations can page-fault across HBM + EGM
|
||||
// 2. cudaMemAdviseSetPreferredLocation(GPU) → driver prefers keeping pages on GPU
|
||||
// 3. cudaMemAdviseSetAccessedBy(CPU) → CPU can access over C2C NVLink without
|
||||
// triggering page migration back to system RAM (critical: prevents OOM)
|
||||
// 4. Selective prefetching — small allocations (model weights, <2 GiB)
|
||||
// are prefetched to GPU so cuBLAS/cuDNN kernels can access them
|
||||
// directly from HBM. Large allocations (KV cache blocks) stay in
|
||||
// managed memory and page-fault on demand, since they're too large
|
||||
// to fit in HBM and attention ops can tolerate page faults.
|
||||
#include <cuda_runtime.h>
|
||||
#include <stdio.h>
|
||||
|
||||
extern "C" {
|
||||
|
||||
// PyTorch pluggable allocator signature: void*(size_t, int, cudaStream_t)
|
||||
void* managed_malloc(size_t size, int device, cudaStream_t stream) {
|
||||
void* ptr = nullptr;
|
||||
|
||||
// Set the device before allocating
|
||||
cudaError_t err = cudaSetDevice(device);
|
||||
if (err != cudaSuccess) {
|
||||
fprintf(stderr, "[managed_alloc] cudaSetDevice(%d) failed: %s\n",
|
||||
device, cudaGetErrorString(err));
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Use cudaMallocManaged - this is the key: allocations can page-fault
|
||||
// across HBM and LPDDR on GH200 with EGM enabled
|
||||
err = cudaMallocManaged(&ptr, size, cudaMemAttachGlobal);
|
||||
if (err != cudaSuccess) {
|
||||
fprintf(stderr, "[managed_alloc] cudaMallocManaged failed: %s "
|
||||
"(size=%zu bytes / %.2f GiB)\n",
|
||||
cudaGetErrorString(err), size, (double)size / (1024.0*1024.0*1024.0));
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// CUDA 13+ uses cudaMemLocation struct instead of int for device
|
||||
cudaMemLocation gpu_loc;
|
||||
gpu_loc.type = cudaMemLocationTypeDevice;
|
||||
gpu_loc.id = device;
|
||||
|
||||
// Advise: prefer GPU placement. On GH200 with EGM, the hardware will
|
||||
// migrate pages as needed, but the driver tries to keep them on GPU.
|
||||
cudaMemAdvise(ptr, size, cudaMemAdviseSetPreferredLocation, gpu_loc);
|
||||
|
||||
// Advise: CPU will access this memory too. On GH200, this sets up
|
||||
// remote mapping over C2C NVLink so CPU can read/write without
|
||||
// triggering page migration back to system RAM. This is CRITICAL
|
||||
// to prevent OOM on EGM systems where most system RAM was carved
|
||||
// out for the GPU.
|
||||
cudaMemLocation cpu_loc;
|
||||
cpu_loc.type = cudaMemLocationTypeHost;
|
||||
cpu_loc.id = cudaCpuDeviceId;
|
||||
cudaMemAdvise(ptr, size, cudaMemAdviseSetAccessedBy, cpu_loc);
|
||||
|
||||
// Selective prefetch: migrate pages to GPU for small allocations only.
|
||||
// Model weights (individual tensors) are typically <2 GiB and MUST be
|
||||
// on GPU for cuBLAS GEMM operations — GPU compute kernels cannot
|
||||
// page-fault into managed memory during execution.
|
||||
// KV cache blocks are large and numerous; prefetching them all fills
|
||||
// HBM and causes subsequent allocations to fail.
|
||||
// The 2 GiB threshold separates "compute data" from "cache data".
|
||||
const size_t PREFETCH_THRESHOLD = 2ULL * 1024 * 1024 * 1024; // 2 GiB
|
||||
|
||||
if (size > 0 && size < PREFETCH_THRESHOLD) {
|
||||
err = cudaMemPrefetchAsync(ptr, size, gpu_loc, 0);
|
||||
if (err != cudaSuccess) {
|
||||
// Non-fatal: prefetch failure shouldn't prevent allocation.
|
||||
// Pages will still be migrated on demand.
|
||||
fprintf(stderr, "[managed_alloc] cudaMemPrefetchAsync warning: %s "
|
||||
"(size=%.2f GiB, will use on-demand migration)\n",
|
||||
cudaGetErrorString(err), (double)size / (1024.0*1024.0*1024.0));
|
||||
}
|
||||
}
|
||||
|
||||
return ptr;
|
||||
}
|
||||
|
||||
// PyTorch pluggable allocator signature: void(void*, size_t, int, cudaStream_t)
|
||||
void managed_free(void* ptr, size_t size, int device, cudaStream_t stream) {
|
||||
if (ptr != nullptr) {
|
||||
// Sync the stream before freeing to avoid use-after-free with
|
||||
// managed memory (in-flight page faults can race with deallocation).
|
||||
if (stream != nullptr) {
|
||||
cudaStreamSynchronize(stream);
|
||||
}
|
||||
cudaFree(ptr);
|
||||
}
|
||||
}
|
||||
|
||||
} // extern "C"
|
||||
@@ -9,8 +9,8 @@ torchaudio==2.10.0
|
||||
# These must be updated alongside torch
|
||||
torchvision==0.25.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
|
||||
# FlashInfer should be updated together with the Dockerfile
|
||||
flashinfer-python==0.6.7
|
||||
flashinfer-cubin==0.6.7
|
||||
flashinfer-python==0.6.6
|
||||
flashinfer-cubin==0.6.6
|
||||
# Cap nvidia-cudnn-frontend (transitive dep of flashinfer) due to
|
||||
# breaking changes in 1.19.0
|
||||
nvidia-cudnn-frontend>=1.13.0,<1.19.0
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
# --- Test Infrastructure ---
|
||||
tblib
|
||||
pytest
|
||||
pytest_asyncio
|
||||
pytest-timeout
|
||||
pytest-cov
|
||||
pytest-forked
|
||||
@@ -9,13 +7,8 @@ pytest-rerunfailures
|
||||
pytest-shard
|
||||
|
||||
# --- Core Tools & Bindings ---
|
||||
|
||||
absl-py
|
||||
accelerate
|
||||
arctic-inference
|
||||
hf_transfer
|
||||
lm_eval[api]
|
||||
modelscope
|
||||
|
||||
# --- Audio Processing ---
|
||||
librosa
|
||||
|
||||
@@ -1,730 +1,42 @@
|
||||
# XPU Test Dependencies
|
||||
# NOTE: Base image already has common.txt + xpu.txt installed,
|
||||
# and vllm-openai stage has pytest, pytest-asyncio, lm-eval[api].
|
||||
# This file only adds incremental test-specific packages.
|
||||
|
||||
# Additional test infrastructure (pytest/pytest-asyncio already in base)
|
||||
# This file was autogenerated by uv via the following command:
|
||||
# uv pip compile requirements/xpu-test.in -o requirements/xpu-test.txt -c requirements/xpu.txt --python-version 3.12 --index-strategy unsafe-best-match
|
||||
absl-py==2.4.0
|
||||
# via
|
||||
# -r requirements/xpu-test.in
|
||||
# rouge-score
|
||||
accelerate==1.13.0
|
||||
# via -r requirements/xpu-test.in
|
||||
aiohappyeyeballs==2.6.1
|
||||
# via aiohttp
|
||||
aiohttp==3.13.4
|
||||
# via
|
||||
# -c requirements/common.txt
|
||||
# fsspec
|
||||
# gpt-oss
|
||||
# lm-eval
|
||||
aiosignal==1.4.0
|
||||
# via aiohttp
|
||||
albumentations==1.4.6
|
||||
# via -r requirements/xpu-test.in
|
||||
annotated-doc==0.0.4
|
||||
# via fastapi
|
||||
annotated-types==0.7.0
|
||||
# via pydantic
|
||||
anyio==4.13.0
|
||||
# via
|
||||
# httpx
|
||||
# starlette
|
||||
arctic-inference==0.1.1
|
||||
# via -r requirements/xpu-test.in
|
||||
attrs==26.1.0
|
||||
# via
|
||||
# aiohttp
|
||||
# jsonlines
|
||||
# jsonschema
|
||||
# referencing
|
||||
audioread==3.0.1
|
||||
# via
|
||||
# -r requirements/xpu-test.in
|
||||
# librosa
|
||||
blobfile==3.0.0
|
||||
# via -r requirements/xpu-test.in
|
||||
bm25s==0.2.13
|
||||
# via
|
||||
# -r requirements/xpu-test.in
|
||||
# mteb
|
||||
bounded-pool-executor==0.0.3
|
||||
# via pqdm
|
||||
certifi==2026.2.25
|
||||
# via
|
||||
# httpcore
|
||||
# httpx
|
||||
# requests
|
||||
cffi==2.0.0
|
||||
# via soundfile
|
||||
chardet==5.2.0
|
||||
# via mbstrdecoder
|
||||
charset-normalizer==3.4.6
|
||||
# via requests
|
||||
chz==0.4.0
|
||||
# via gpt-oss
|
||||
click==8.3.1
|
||||
# via
|
||||
# jiwer
|
||||
# nltk
|
||||
# schemathesis
|
||||
# uvicorn
|
||||
colorama==0.4.6
|
||||
# via sacrebleu
|
||||
coverage==7.13.5
|
||||
# via pytest-cov
|
||||
dataproperty==1.1.0
|
||||
# via
|
||||
# pytablewriter
|
||||
# tabledata
|
||||
datasets==4.8.4
|
||||
# via
|
||||
# evaluate
|
||||
# lm-eval
|
||||
# mteb
|
||||
decorator==5.2.1
|
||||
# via librosa
|
||||
dill==0.4.1
|
||||
# via
|
||||
# datasets
|
||||
# evaluate
|
||||
# lm-eval
|
||||
# multiprocess
|
||||
docker==7.1.0
|
||||
# via gpt-oss
|
||||
docopt==0.6.2
|
||||
# via num2words
|
||||
dpcpp-cpp-rt==2025.3.1
|
||||
# via
|
||||
# onemkl-sycl-blas
|
||||
# onemkl-sycl-dft
|
||||
# onemkl-sycl-lapack
|
||||
# onemkl-sycl-rng
|
||||
# onemkl-sycl-sparse
|
||||
# torch
|
||||
evaluate==0.4.6
|
||||
# via lm-eval
|
||||
fastapi==0.135.2
|
||||
# via
|
||||
# -c requirements/common.txt
|
||||
# gpt-oss
|
||||
filelock==3.25.2
|
||||
# via
|
||||
# -c requirements/common.txt
|
||||
# blobfile
|
||||
# datasets
|
||||
# huggingface-hub
|
||||
# modelscope
|
||||
# torch
|
||||
# transformers
|
||||
frozenlist==1.8.0
|
||||
# via
|
||||
# aiohttp
|
||||
# aiosignal
|
||||
fsspec==2026.2.0
|
||||
# via
|
||||
# datasets
|
||||
# evaluate
|
||||
# huggingface-hub
|
||||
# torch
|
||||
gpt-oss==0.0.8
|
||||
# via -r requirements/xpu-test.in
|
||||
graphql-core==3.2.8
|
||||
# via hypothesis-graphql
|
||||
h11==0.16.0
|
||||
# via
|
||||
# httpcore
|
||||
# uvicorn
|
||||
harfile==0.4.0
|
||||
# via schemathesis
|
||||
hf-transfer==0.1.9
|
||||
# via -r requirements/xpu-test.in
|
||||
hf-xet==1.4.2
|
||||
# via huggingface-hub
|
||||
html2text==2025.4.15
|
||||
# via gpt-oss
|
||||
httpcore==1.0.9
|
||||
# via httpx
|
||||
httpx==0.28.1
|
||||
# via
|
||||
# datasets
|
||||
# schemathesis
|
||||
huggingface-hub==0.36.2
|
||||
# via
|
||||
# accelerate
|
||||
# datasets
|
||||
# evaluate
|
||||
# sentence-transformers
|
||||
# timm
|
||||
# tokenizers
|
||||
# transformers
|
||||
hypothesis==6.151.10
|
||||
# via
|
||||
# hypothesis-graphql
|
||||
# hypothesis-jsonschema
|
||||
# schemathesis
|
||||
hypothesis-graphql==0.12.0
|
||||
# via schemathesis
|
||||
hypothesis-jsonschema==0.23.1
|
||||
# via schemathesis
|
||||
idna==3.11
|
||||
# via
|
||||
# anyio
|
||||
# httpx
|
||||
# requests
|
||||
# yarl
|
||||
imageio==2.37.3
|
||||
# via scikit-image
|
||||
impi-rt==2021.17.0
|
||||
# via
|
||||
# oneccl
|
||||
# torch
|
||||
iniconfig==2.3.0
|
||||
# via pytest
|
||||
intel-cmplr-lib-rt==2025.3.1
|
||||
# via
|
||||
# intel-sycl-rt
|
||||
# torch
|
||||
intel-cmplr-lib-ur==2025.3.1
|
||||
# via
|
||||
# intel-openmp
|
||||
# intel-sycl-rt
|
||||
# torch
|
||||
intel-cmplr-lic-rt==2025.3.1
|
||||
# via
|
||||
# intel-opencl-rt
|
||||
# intel-sycl-rt
|
||||
# torch
|
||||
intel-opencl-rt==2025.3.1
|
||||
# via
|
||||
# dpcpp-cpp-rt
|
||||
# onemkl-sycl-blas
|
||||
# onemkl-sycl-dft
|
||||
# onemkl-sycl-lapack
|
||||
# onemkl-sycl-rng
|
||||
# onemkl-sycl-sparse
|
||||
# torch
|
||||
intel-openmp==2025.3.1
|
||||
# via
|
||||
# dpcpp-cpp-rt
|
||||
# mkl
|
||||
# torch
|
||||
intel-pti==0.15.0
|
||||
# via torch
|
||||
intel-sycl-rt==2025.3.1
|
||||
# via
|
||||
# dpcpp-cpp-rt
|
||||
# oneccl
|
||||
# torch
|
||||
jinja2==3.1.6
|
||||
# via
|
||||
# -c requirements/xpu.txt
|
||||
# lm-eval
|
||||
# torch
|
||||
jiwer==4.0.0
|
||||
# via -r requirements/xpu-test.in
|
||||
joblib==1.5.3
|
||||
# via
|
||||
# librosa
|
||||
# nltk
|
||||
# scikit-learn
|
||||
jsonlines==4.0.0
|
||||
# via lm-eval
|
||||
jsonschema==4.26.0
|
||||
# via
|
||||
# hypothesis-jsonschema
|
||||
# mistral-common
|
||||
# schemathesis
|
||||
jsonschema-rs==0.45.0
|
||||
# via schemathesis
|
||||
jsonschema-specifications==2025.9.1
|
||||
# via jsonschema
|
||||
junit-xml==1.9
|
||||
# via schemathesis
|
||||
lazy-loader==0.5
|
||||
# via
|
||||
# librosa
|
||||
# scikit-image
|
||||
librosa==0.10.2.post1
|
||||
# via -r requirements/xpu-test.in
|
||||
llvmlite==0.44.0
|
||||
# via numba
|
||||
lm-eval==0.4.11
|
||||
# via -r requirements/xpu-test.in
|
||||
lxml==6.0.2
|
||||
# via
|
||||
# blobfile
|
||||
# gpt-oss
|
||||
# sacrebleu
|
||||
markdown-it-py==4.0.0
|
||||
# via rich
|
||||
markupsafe==3.0.3
|
||||
# via
|
||||
# jinja2
|
||||
# werkzeug
|
||||
mbstrdecoder==1.1.4
|
||||
# via
|
||||
# dataproperty
|
||||
# pytablewriter
|
||||
# typepy
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
mistral-common==1.10.0
|
||||
# via
|
||||
# -c requirements/common.txt
|
||||
# -r requirements/xpu-test.in
|
||||
mkl==2025.3.0
|
||||
# via
|
||||
# onemkl-sycl-blas
|
||||
# onemkl-sycl-dft
|
||||
# onemkl-sycl-lapack
|
||||
# onemkl-sycl-rng
|
||||
# onemkl-sycl-sparse
|
||||
# torch
|
||||
modelscope==1.35.3
|
||||
# via -r requirements/xpu-test.in
|
||||
more-itertools==10.8.0
|
||||
# via lm-eval
|
||||
mpmath==1.3.0
|
||||
# via sympy
|
||||
msgpack==1.1.2
|
||||
# via librosa
|
||||
mteb==2.12.7
|
||||
# via -r requirements/xpu-test.in
|
||||
multidict==6.7.1
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
multiprocess==0.70.19
|
||||
# via
|
||||
# datasets
|
||||
# evaluate
|
||||
networkx==3.6.1
|
||||
# via
|
||||
# scikit-image
|
||||
# torch
|
||||
nltk==3.9.4
|
||||
# via rouge-score
|
||||
num2words==0.5.14
|
||||
# via -r requirements/xpu-test.in
|
||||
numba==0.61.2
|
||||
# via
|
||||
# -c requirements/xpu.txt
|
||||
# librosa
|
||||
numpy==2.2.6
|
||||
# via
|
||||
# accelerate
|
||||
# albumentations
|
||||
# bm25s
|
||||
# datasets
|
||||
# evaluate
|
||||
# imageio
|
||||
# librosa
|
||||
# lm-eval
|
||||
# mistral-common
|
||||
# mteb
|
||||
# numba
|
||||
# opencv-python-headless
|
||||
# pandas
|
||||
# pytrec-eval-terrier
|
||||
# rouge-score
|
||||
# sacrebleu
|
||||
# scikit-image
|
||||
# scikit-learn
|
||||
# scipy
|
||||
# sentence-transformers
|
||||
# soundfile
|
||||
# soxr
|
||||
# tifffile
|
||||
# torchvision
|
||||
# transformers
|
||||
oneccl==2021.17.1
|
||||
# via
|
||||
# oneccl-devel
|
||||
# torch
|
||||
oneccl-devel==2021.17.1
|
||||
# via torch
|
||||
onemkl-license==2025.3.0
|
||||
# via
|
||||
# mkl
|
||||
# torch
|
||||
onemkl-sycl-blas==2025.3.0
|
||||
# via
|
||||
# onemkl-sycl-lapack
|
||||
# onemkl-sycl-sparse
|
||||
# torch
|
||||
onemkl-sycl-dft==2025.3.0
|
||||
# via torch
|
||||
onemkl-sycl-lapack==2025.3.0
|
||||
# via torch
|
||||
onemkl-sycl-rng==2025.3.0
|
||||
# via torch
|
||||
onemkl-sycl-sparse==2025.3.0
|
||||
# via torch
|
||||
openai-harmony==0.0.8
|
||||
# via
|
||||
# -c requirements/common.txt
|
||||
# gpt-oss
|
||||
opencv-python-headless==4.13.0.92
|
||||
# via
|
||||
# -c requirements/common.txt
|
||||
# albumentations
|
||||
# mistral-common
|
||||
packaging==26.0
|
||||
# via
|
||||
# -c requirements/xpu.txt
|
||||
# accelerate
|
||||
# datasets
|
||||
# evaluate
|
||||
# huggingface-hub
|
||||
# lazy-loader
|
||||
# modelscope
|
||||
# pooch
|
||||
# pytest
|
||||
# pytest-rerunfailures
|
||||
# scikit-image
|
||||
# transformers
|
||||
# typepy
|
||||
pandas==3.0.1
|
||||
# via
|
||||
# datasets
|
||||
# evaluate
|
||||
pathvalidate==3.3.1
|
||||
# via pytablewriter
|
||||
pillow==12.1.1
|
||||
# via
|
||||
# imageio
|
||||
# mistral-common
|
||||
# scikit-image
|
||||
# torchvision
|
||||
platformdirs==4.9.4
|
||||
# via pooch
|
||||
pluggy==1.6.0
|
||||
# via
|
||||
# pytest
|
||||
# pytest-cov
|
||||
polars==1.39.3
|
||||
# via mteb
|
||||
polars-runtime-32==1.39.3
|
||||
# via polars
|
||||
pooch==1.8.2
|
||||
# via
|
||||
# -r requirements/xpu-test.in
|
||||
# librosa
|
||||
portalocker==3.2.0
|
||||
# via sacrebleu
|
||||
pqdm==0.2.0
|
||||
# via -r requirements/xpu-test.in
|
||||
propcache==0.4.1
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
psutil==7.2.2
|
||||
# via accelerate
|
||||
py==1.11.0
|
||||
# via pytest-forked
|
||||
pyarrow==23.0.1
|
||||
# via datasets
|
||||
pycountry==26.2.16
|
||||
# via pydantic-extra-types
|
||||
pycparser==3.0
|
||||
# via cffi
|
||||
pycryptodomex==3.23.0
|
||||
# via blobfile
|
||||
pydantic==2.12.5
|
||||
# via
|
||||
# -c requirements/common.txt
|
||||
# albumentations
|
||||
# fastapi
|
||||
# gpt-oss
|
||||
# mistral-common
|
||||
# mteb
|
||||
# openai-harmony
|
||||
# pydantic-extra-types
|
||||
pydantic-core==2.41.5
|
||||
# via pydantic
|
||||
pydantic-extra-types==2.11.1
|
||||
# via mistral-common
|
||||
pyelftools==0.32
|
||||
# via triton-xpu
|
||||
pygments==2.20.0
|
||||
# via
|
||||
# pytest
|
||||
# rich
|
||||
pyrate-limiter==4.1.0
|
||||
# via schemathesis
|
||||
pystemmer==3.0.0
|
||||
# via
|
||||
# -r requirements/xpu-test.in
|
||||
# mteb
|
||||
pytablewriter==1.2.1
|
||||
# via lm-eval
|
||||
pytest==9.0.2
|
||||
# via
|
||||
# -r requirements/xpu-test.in
|
||||
# pytest-asyncio
|
||||
# pytest-cov
|
||||
# pytest-forked
|
||||
# pytest-rerunfailures
|
||||
# pytest-shard
|
||||
# pytest-timeout
|
||||
# schemathesis
|
||||
pytest-asyncio==1.3.0
|
||||
# via -r requirements/xpu-test.in
|
||||
pytest-cov==6.3.0
|
||||
# via -r requirements/xpu-test.in
|
||||
pytest-forked==1.6.0
|
||||
# via -r requirements/xpu-test.in
|
||||
pytest-rerunfailures==14.0
|
||||
# via -r requirements/xpu-test.in
|
||||
pytest-shard==0.1.2
|
||||
# via -r requirements/xpu-test.in
|
||||
pytest-timeout==2.3.1
|
||||
# via -r requirements/xpu-test.in
|
||||
python-dateutil==2.9.0.post0
|
||||
# via
|
||||
# pandas
|
||||
# typepy
|
||||
pytrec-eval-terrier==0.5.10
|
||||
# via mteb
|
||||
pytz==2026.1.post1
|
||||
# via typepy
|
||||
pyyaml==6.0.3
|
||||
# via
|
||||
# accelerate
|
||||
# albumentations
|
||||
# datasets
|
||||
# huggingface-hub
|
||||
# schemathesis
|
||||
# timm
|
||||
# transformers
|
||||
rapidfuzz==3.12.1
|
||||
# via
|
||||
# -r requirements/xpu-test.in
|
||||
# jiwer
|
||||
referencing==0.37.0
|
||||
# via
|
||||
# jsonschema
|
||||
# jsonschema-specifications
|
||||
regex==2026.3.32
|
||||
# via
|
||||
# nltk
|
||||
# sacrebleu
|
||||
# tiktoken
|
||||
# transformers
|
||||
requests==2.33.1
|
||||
# via
|
||||
# -c requirements/common.txt
|
||||
# datasets
|
||||
# docker
|
||||
# evaluate
|
||||
# gpt-oss
|
||||
# huggingface-hub
|
||||
# lm-eval
|
||||
# mistral-common
|
||||
# modelscope
|
||||
# mteb
|
||||
# pooch
|
||||
# schemathesis
|
||||
# starlette-testclient
|
||||
# tiktoken
|
||||
# transformers
|
||||
rich==14.3.3
|
||||
# via
|
||||
# mteb
|
||||
# schemathesis
|
||||
rouge-score==0.1.2
|
||||
# via lm-eval
|
||||
rpds-py==0.30.0
|
||||
# via
|
||||
# jsonschema
|
||||
# referencing
|
||||
sacrebleu==2.6.0
|
||||
# via lm-eval
|
||||
safetensors==0.7.0
|
||||
# via
|
||||
# accelerate
|
||||
# timm
|
||||
# transformers
|
||||
schemathesis==4.14.2
|
||||
# via -r requirements/xpu-test.in
|
||||
scikit-image==0.26.0
|
||||
# via albumentations
|
||||
scikit-learn==1.8.0
|
||||
# via
|
||||
# albumentations
|
||||
# librosa
|
||||
# lm-eval
|
||||
# mteb
|
||||
# sentence-transformers
|
||||
scipy==1.17.1
|
||||
# via
|
||||
# albumentations
|
||||
# bm25s
|
||||
# librosa
|
||||
# mteb
|
||||
# pytrec-eval-terrier
|
||||
# scikit-image
|
||||
# scikit-learn
|
||||
# sentence-transformers
|
||||
sentence-transformers==5.3.0
|
||||
# via mteb
|
||||
setuptools==80.10.2
|
||||
# via
|
||||
# -c requirements/common.txt
|
||||
# -c requirements/xpu.txt
|
||||
# modelscope
|
||||
# pytablewriter
|
||||
# torch
|
||||
six==1.17.0
|
||||
# via
|
||||
# -c requirements/common.txt
|
||||
# junit-xml
|
||||
# python-dateutil
|
||||
# rouge-score
|
||||
sortedcontainers==2.4.0
|
||||
# via hypothesis
|
||||
soundfile==0.13.1
|
||||
# via
|
||||
# -r requirements/xpu-test.in
|
||||
# librosa
|
||||
# mistral-common
|
||||
soxr==0.5.0.post1
|
||||
# via
|
||||
# -r requirements/xpu-test.in
|
||||
# librosa
|
||||
# mistral-common
|
||||
sqlitedict==2.1.0
|
||||
# via lm-eval
|
||||
starlette==1.0.0
|
||||
# via
|
||||
# fastapi
|
||||
# starlette-testclient
|
||||
starlette-testclient==0.4.1
|
||||
# via schemathesis
|
||||
structlog==25.5.0
|
||||
# via gpt-oss
|
||||
sympy==1.14.0
|
||||
# via torch
|
||||
tabledata==1.3.4
|
||||
# via pytablewriter
|
||||
tabulate==0.10.0
|
||||
# via sacrebleu
|
||||
tbb==2022.3.0
|
||||
# via
|
||||
# intel-opencl-rt
|
||||
# mkl
|
||||
# torch
|
||||
# uv pip compile /workspace/vllm/requirements/xpu-test.in -o /workspace/vllm/requirements/xpu-test.txt -c /workspace/vllm/requirements/xpu.txt --index-strategy unsafe-best-match --extra-index-url ${PIP_EXTRA_INDEX_URL} --python-version ${PYTHON_VERSION}
|
||||
tblib==3.1.0
|
||||
# via -r requirements/xpu-test.in
|
||||
tcmlib==1.4.1
|
||||
# via
|
||||
# tbb
|
||||
# torch
|
||||
# umf
|
||||
tcolorpy==0.1.7
|
||||
# via pytablewriter
|
||||
tenacity==9.1.4
|
||||
# via
|
||||
# gpt-oss
|
||||
# lm-eval
|
||||
# schemathesis
|
||||
termcolor==3.3.0
|
||||
# via gpt-oss
|
||||
threadpoolctl==3.6.0
|
||||
# via scikit-learn
|
||||
tifffile==2026.3.3
|
||||
# via scikit-image
|
||||
tiktoken==0.12.0
|
||||
# via
|
||||
# -c requirements/common.txt
|
||||
# gpt-oss
|
||||
# lm-eval
|
||||
# mistral-common
|
||||
pytest-timeout==2.3.1
|
||||
pytest-cov==6.3.0
|
||||
pytest-forked==1.6.0
|
||||
pytest-rerunfailures==14.0
|
||||
pytest-shard==0.1.2
|
||||
|
||||
arctic-inference==0.1.1
|
||||
|
||||
# Required for audio processing tests
|
||||
librosa==0.10.2.post1
|
||||
audioread==3.0.1
|
||||
soxr==0.5.0.post1
|
||||
pooch==1.8.2
|
||||
soundfile==0.13.1
|
||||
|
||||
# Required for Mistral's streaming tool parser
|
||||
blobfile==3.0.0
|
||||
rapidfuzz==3.12.1
|
||||
|
||||
# Required for Mistral's streaming tool parser and some evaluation scripts
|
||||
gpt-oss==0.0.8
|
||||
schemathesis==3.39.15
|
||||
jiwer==4.0.0
|
||||
bm25s==0.2.13
|
||||
pystemmer==3.0.0
|
||||
mteb[bm25s]>=2, <3
|
||||
num2words==0.5.14
|
||||
pqdm==0.2.0
|
||||
|
||||
# Required for some evaluation scripts
|
||||
timm==1.0.17
|
||||
# via -r requirements/xpu-test.in
|
||||
tokenizers==0.22.2
|
||||
# via
|
||||
# -c requirements/common.txt
|
||||
# transformers
|
||||
torch==2.10.0+xpu
|
||||
# via
|
||||
# -c requirements/xpu.txt
|
||||
# accelerate
|
||||
# mteb
|
||||
# sentence-transformers
|
||||
# timm
|
||||
# torchvision
|
||||
torchvision==0.25.0+xpu
|
||||
# via timm
|
||||
tqdm==4.67.3
|
||||
# via
|
||||
# datasets
|
||||
# evaluate
|
||||
# huggingface-hub
|
||||
# lm-eval
|
||||
# modelscope
|
||||
# mteb
|
||||
# nltk
|
||||
# pqdm
|
||||
# sentence-transformers
|
||||
# transformers
|
||||
transformers==4.57.6
|
||||
# via
|
||||
# -c requirements/common.txt
|
||||
# sentence-transformers
|
||||
triton-xpu==3.6.0
|
||||
# via torch
|
||||
typepy==1.3.4
|
||||
# via
|
||||
# dataproperty
|
||||
# pytablewriter
|
||||
# tabledata
|
||||
typing-extensions==4.15.0
|
||||
# via
|
||||
# -c requirements/common.txt
|
||||
# aiosignal
|
||||
# albumentations
|
||||
# anyio
|
||||
# chz
|
||||
# fastapi
|
||||
# huggingface-hub
|
||||
# librosa
|
||||
# lm-eval
|
||||
# mistral-common
|
||||
# mteb
|
||||
# pqdm
|
||||
# pydantic
|
||||
# pydantic-core
|
||||
# pydantic-extra-types
|
||||
# pytest-asyncio
|
||||
# referencing
|
||||
# schemathesis
|
||||
# sentence-transformers
|
||||
# starlette
|
||||
# torch
|
||||
# typing-inspection
|
||||
typing-inspection==0.4.2
|
||||
# via
|
||||
# fastapi
|
||||
# pydantic
|
||||
umf==1.0.2
|
||||
# via
|
||||
# intel-cmplr-lib-ur
|
||||
# torch
|
||||
urllib3==2.6.3
|
||||
# via
|
||||
# blobfile
|
||||
# docker
|
||||
# modelscope
|
||||
# requests
|
||||
uvicorn==0.42.0
|
||||
# via gpt-oss
|
||||
werkzeug==3.1.7
|
||||
# via schemathesis
|
||||
word2number==1.1
|
||||
# via lm-eval
|
||||
xxhash==3.6.0
|
||||
# via
|
||||
# datasets
|
||||
# evaluate
|
||||
yarl==1.23.0
|
||||
# via aiohttp
|
||||
zstandard==0.25.0
|
||||
# via lm-eval
|
||||
albumentations==1.4.6
|
||||
mistral-common[image,audio]==1.9.1
|
||||
@@ -6,6 +6,7 @@ import pytest
|
||||
|
||||
from vllm.config import CompilationMode
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||
|
||||
from ...utils import compare_all_settings
|
||||
|
||||
@@ -108,10 +109,10 @@ def test_compile_correctness(
|
||||
tp_size = test_setting.tp_size
|
||||
attn_backend = test_setting.attn_backend
|
||||
method = test_setting.method
|
||||
if current_platform.device_count() < pp_size * tp_size:
|
||||
if cuda_device_count_stateless() < pp_size * tp_size:
|
||||
pytest.skip(
|
||||
f"Need at least {pp_size}*{tp_size} CUDA gpus but got "
|
||||
f"{current_platform.device_count()}"
|
||||
f"{cuda_device_count_stateless()}"
|
||||
)
|
||||
|
||||
final_args = [
|
||||
|
||||
@@ -412,7 +412,7 @@ def test_cudagraph_sizes_post_init(
|
||||
|
||||
with (
|
||||
ctx,
|
||||
patch.object(current_platform, "device_count", return_value=tp_size),
|
||||
patch("vllm.config.parallel.cuda_device_count_stateless", return_value=tp_size),
|
||||
):
|
||||
kwargs = {}
|
||||
if cudagraph_capture_sizes is not None:
|
||||
@@ -577,6 +577,48 @@ def test_compile_sizes_padding_validation():
|
||||
dispatcher.initialize_cudagraph_keys(CUDAGraphMode.NONE) # Should not raise
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"capture_sizes, max_size, num_blocks, expected_sizes, expected_max",
|
||||
[
|
||||
# Normal capping: sizes filtered to <= num_blocks
|
||||
(
|
||||
[1, 2, 4, 8, 16, 32, 64, 128, 256, 512],
|
||||
512,
|
||||
200,
|
||||
[1, 2, 4, 8, 16, 32, 64, 128],
|
||||
128,
|
||||
),
|
||||
# No capping needed: num_blocks >= max
|
||||
([1, 2, 4, 8, 16], 16, 1000, [1, 2, 4, 8, 16], 16),
|
||||
# Exact boundary: num_blocks == max (no capping)
|
||||
([1, 2, 4, 8, 16, 32], 32, 32, [1, 2, 4, 8, 16, 32], 32),
|
||||
# All sizes capped: num_blocks < smallest size
|
||||
([8, 16, 32], 32, 4, [], 0),
|
||||
# num_blocks <= 0: early return, no change
|
||||
([1, 2, 4], 4, 0, [1, 2, 4], 4),
|
||||
],
|
||||
)
|
||||
def test_adjust_cudagraph_sizes_for_mamba_cache(
|
||||
capture_sizes, max_size, num_blocks, expected_sizes, expected_max
|
||||
):
|
||||
"""Test that cudagraph capture sizes are correctly capped to fit
|
||||
available Mamba cache blocks.
|
||||
|
||||
See: https://github.com/vllm-project/vllm/issues/34094
|
||||
"""
|
||||
config = CompilationConfig(
|
||||
cudagraph_capture_sizes=capture_sizes,
|
||||
max_cudagraph_capture_size=max_size,
|
||||
cudagraph_mode=CUDAGraphMode.NONE,
|
||||
)
|
||||
config.adjust_cudagraph_sizes_for_mamba_cache(num_blocks)
|
||||
assert config.cudagraph_capture_sizes == expected_sizes
|
||||
assert config.max_cudagraph_capture_size == expected_max
|
||||
# Invariant: last element == max_cudagraph_capture_size
|
||||
if expected_sizes:
|
||||
assert config.cudagraph_capture_sizes[-1] == config.max_cudagraph_capture_size
|
||||
|
||||
|
||||
def test_inductor_asserts_default_disabled(monkeypatch):
|
||||
"""Test that inductor runtime asserts are disabled by default
|
||||
(INFO logging level) on torch < 2.12."""
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import atexit
|
||||
import os
|
||||
import random
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
@@ -18,20 +16,9 @@ from vllm.utils.system_utils import update_environment_variables
|
||||
mp.set_start_method("spawn", force=True)
|
||||
|
||||
|
||||
def _distributed_worker_wrapper(fn, env, world_size, args, rank, skip_queue):
|
||||
try:
|
||||
fn(env, world_size, *args)
|
||||
except BaseException as exc:
|
||||
if isinstance(exc, pytest.skip.Exception):
|
||||
skip_queue.put((rank, str(exc)))
|
||||
return
|
||||
raise
|
||||
|
||||
|
||||
def distributed_run(fn, world_size, *args):
|
||||
number_of_processes = world_size
|
||||
processes: list[mp.Process] = []
|
||||
skip_queue: mp.SimpleQueue = mp.SimpleQueue()
|
||||
for i in range(number_of_processes):
|
||||
env: dict[str, str] = {}
|
||||
env["RANK"] = str(i)
|
||||
@@ -40,32 +27,13 @@ def distributed_run(fn, world_size, *args):
|
||||
env["LOCAL_WORLD_SIZE"] = str(number_of_processes)
|
||||
env["MASTER_ADDR"] = "localhost"
|
||||
env["MASTER_PORT"] = "12345"
|
||||
p = mp.Process(
|
||||
target=_distributed_worker_wrapper,
|
||||
args=(fn, env, world_size, args, i, skip_queue),
|
||||
)
|
||||
p = mp.Process(target=fn, args=(env, world_size, *args))
|
||||
processes.append(p)
|
||||
p.start()
|
||||
|
||||
for p in processes:
|
||||
p.join()
|
||||
|
||||
skipped: list[tuple[int, str]] = []
|
||||
while not skip_queue.empty():
|
||||
rank, reason = skip_queue.get()
|
||||
skipped.append((rank, reason))
|
||||
|
||||
if len(skipped) == number_of_processes:
|
||||
reason = skipped[0][1]
|
||||
pytest.skip(reason)
|
||||
if 0 < len(skipped) < number_of_processes:
|
||||
skipped_ranks = sorted(rank for rank, _ in skipped)
|
||||
raise AssertionError(
|
||||
"Distributed test had partial skips; expected either all ranks "
|
||||
f"to skip or none. Skipped ranks: {skipped_ranks}, "
|
||||
f"total ranks: {number_of_processes}"
|
||||
)
|
||||
|
||||
for p in processes:
|
||||
assert p.exitcode == 0
|
||||
|
||||
@@ -80,12 +48,7 @@ def set_env_vars_and_device(env: dict[str, str]) -> None:
|
||||
vllm_config = VllmConfig()
|
||||
with set_current_vllm_config(vllm_config):
|
||||
init_distributed_environment()
|
||||
atexit.register(_destroy_process_group_if_initialized)
|
||||
|
||||
# Ensure each worker process has the same random seed
|
||||
random.seed(42)
|
||||
torch.manual_seed(42)
|
||||
|
||||
|
||||
def _destroy_process_group_if_initialized() -> None:
|
||||
if torch.distributed.is_available() and torch.distributed.is_initialized():
|
||||
torch.distributed.destroy_process_group()
|
||||
|
||||
@@ -9,7 +9,6 @@ import torch
|
||||
import torch.distributed
|
||||
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.distributed.eplb.eplb_communicator import create_eplb_communicator
|
||||
from vllm.distributed.eplb.rebalance_execute import (
|
||||
move_from_buffer,
|
||||
rearrange_expert_weights_inplace,
|
||||
@@ -131,10 +130,9 @@ def verify_expert_weights_after_shuffle(
|
||||
hidden_sizes: list[int],
|
||||
ep_rank: int,
|
||||
num_local_experts: int,
|
||||
) -> bool:
|
||||
):
|
||||
"""Verify the weights after shuffling are correct."""
|
||||
num_layers = len(expert_weights)
|
||||
ok = True
|
||||
|
||||
for layer in range(num_layers):
|
||||
for weight_idx, hidden_size in enumerate(hidden_sizes):
|
||||
@@ -157,38 +155,29 @@ def verify_expert_weights_after_shuffle(
|
||||
dtype=actual_weights.dtype,
|
||||
)
|
||||
|
||||
if not torch.equal(actual_weights, expected_weights):
|
||||
ok = False
|
||||
actual_head = actual_weights[:8].detach().cpu().tolist()
|
||||
expected_head = expected_weights[:8].detach().cpu().tolist()
|
||||
print(
|
||||
"verify_expert_weights_after_shuffle failed: "
|
||||
f"rank={ep_rank}, "
|
||||
f"layer={layer}, weight_idx={weight_idx}, "
|
||||
f"local_expert={local_expert}, "
|
||||
f"expected_logical_expert={expected_logical_expert}, "
|
||||
f"actual_head={actual_head}, expected_head={expected_head}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
return ok
|
||||
torch.testing.assert_close(
|
||||
actual_weights,
|
||||
expected_weights,
|
||||
msg=f"Layer {layer}, weight {weight_idx},"
|
||||
f"local expert {local_expert}: "
|
||||
f"weights do not match. "
|
||||
f"Expected logical expert {expected_logical_expert}",
|
||||
)
|
||||
|
||||
|
||||
def verify_redundant_experts_have_same_weights(
|
||||
expert_weights: list[list[torch.Tensor]],
|
||||
indices: torch.Tensor,
|
||||
hidden_sizes: list[int],
|
||||
ep_rank: int,
|
||||
world_size: int,
|
||||
num_local_experts: int,
|
||||
) -> bool:
|
||||
):
|
||||
"""
|
||||
Verify that all replicas of the same logical expert have the same weights.
|
||||
"""
|
||||
num_layers = len(expert_weights)
|
||||
total_physical_experts = world_size * num_local_experts
|
||||
|
||||
ok = True
|
||||
for layer in range(num_layers):
|
||||
# Collect weights for all physical experts for each weight matrix
|
||||
all_weights: list[torch.Tensor] = []
|
||||
@@ -238,54 +227,14 @@ def verify_redundant_experts_have_same_weights(
|
||||
# Verify that current physical expert's weights match the
|
||||
# previously saved logical expert weights
|
||||
for weight_idx in range(len(hidden_sizes)):
|
||||
if not torch.equal(
|
||||
torch.testing.assert_close(
|
||||
all_weights[weight_idx][physical_pos],
|
||||
logical_expert_weights[logical_expert_id][weight_idx],
|
||||
):
|
||||
ok = False
|
||||
actual_head = (
|
||||
all_weights[weight_idx][physical_pos][:8]
|
||||
.detach()
|
||||
.cpu()
|
||||
.tolist()
|
||||
)
|
||||
reference_head = (
|
||||
logical_expert_weights[logical_expert_id][weight_idx][:8]
|
||||
.detach()
|
||||
.cpu()
|
||||
.tolist()
|
||||
)
|
||||
print(
|
||||
"verify_redundant_experts_have_same_weights failed: "
|
||||
f"rank={ep_rank}, "
|
||||
f"layer={layer}, weight_idx={weight_idx}, "
|
||||
f"logical_expert={logical_expert_id}, "
|
||||
f"physical_pos={physical_pos}, "
|
||||
f"actual_head={actual_head}, "
|
||||
f"reference_head={reference_head}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
return ok
|
||||
|
||||
|
||||
def assert_verification_synced(local_ok: bool, msg: str) -> None:
|
||||
ok_tensor = torch.tensor([1 if local_ok else 0], device="cuda", dtype=torch.int32)
|
||||
torch.distributed.all_reduce(ok_tensor, op=torch.distributed.ReduceOp.MIN)
|
||||
assert bool(ok_tensor.item()), msg
|
||||
|
||||
|
||||
def create_eplb_communicator_or_raise(*, group_coordinator, backend, expert_weights):
|
||||
try:
|
||||
return create_eplb_communicator(
|
||||
group_coordinator=group_coordinator,
|
||||
backend=backend,
|
||||
expert_weights=expert_weights,
|
||||
)
|
||||
except Exception as exc:
|
||||
raise RuntimeError(
|
||||
f"Failed to create EPLB communicator for backend={backend}: {exc}"
|
||||
) from exc
|
||||
msg=f"Layer {layer}, weight {weight_idx},"
|
||||
f"logical expert {logical_expert_id}: "
|
||||
f"Physical expert {physical_pos} has different weights"
|
||||
f"than expected",
|
||||
)
|
||||
|
||||
|
||||
def _test_async_transfer_layer_without_mtp_worker(
|
||||
@@ -294,7 +243,6 @@ def _test_async_transfer_layer_without_mtp_worker(
|
||||
num_layers: int,
|
||||
num_local_experts: int,
|
||||
num_logical_experts: int,
|
||||
eplb_communicator: str,
|
||||
) -> None:
|
||||
set_env_vars_and_device(env)
|
||||
|
||||
@@ -306,8 +254,8 @@ def _test_async_transfer_layer_without_mtp_worker(
|
||||
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
|
||||
)
|
||||
|
||||
ep_group_coordinator = get_tp_group()
|
||||
ep_group = ep_group_coordinator.device_group
|
||||
tp_group = get_tp_group()
|
||||
ep_group = tp_group.device_group
|
||||
ep_rank = torch.distributed.get_rank()
|
||||
device = torch.device(f"cuda:{ep_rank}")
|
||||
|
||||
@@ -350,13 +298,6 @@ def _test_async_transfer_layer_without_mtp_worker(
|
||||
expert_buffer = [torch.empty_like(w) for w in expert_weights[0]]
|
||||
cuda_stream = torch.cuda.Stream(device=device)
|
||||
|
||||
communicator = create_eplb_communicator_or_raise(
|
||||
group_coordinator=ep_group_coordinator,
|
||||
backend=eplb_communicator,
|
||||
expert_weights=expert_weights[0],
|
||||
)
|
||||
communicator.set_stream(cuda_stream)
|
||||
|
||||
for layer_idx in range(num_layers):
|
||||
is_unchanged, is_received_locally, recv_metadata = asyncio.run(
|
||||
transfer_layer(
|
||||
@@ -365,7 +306,6 @@ def _test_async_transfer_layer_without_mtp_worker(
|
||||
expert_weights=expert_weights[layer_idx],
|
||||
expert_weights_buffer=expert_buffer,
|
||||
ep_group=ep_group,
|
||||
communicator=communicator,
|
||||
cuda_stream=cuda_stream,
|
||||
)
|
||||
)
|
||||
@@ -380,38 +320,24 @@ def _test_async_transfer_layer_without_mtp_worker(
|
||||
ep_rank=ep_rank,
|
||||
)
|
||||
|
||||
local_ok = verify_expert_weights_after_shuffle(
|
||||
expert_weights,
|
||||
new_indices,
|
||||
hidden_sizes,
|
||||
ep_rank,
|
||||
num_local_experts,
|
||||
)
|
||||
local_ok = (
|
||||
verify_redundant_experts_have_same_weights(
|
||||
verify_expert_weights_after_shuffle(
|
||||
expert_weights,
|
||||
new_indices,
|
||||
hidden_sizes,
|
||||
ep_rank,
|
||||
num_local_experts,
|
||||
)
|
||||
verify_redundant_experts_have_same_weights(
|
||||
expert_weights,
|
||||
new_indices,
|
||||
hidden_sizes,
|
||||
world_size,
|
||||
num_local_experts,
|
||||
)
|
||||
and local_ok
|
||||
)
|
||||
assert_verification_synced(
|
||||
local_ok,
|
||||
"Async transfer verification failed on at least one rank. "
|
||||
"See logs for details.",
|
||||
)
|
||||
|
||||
|
||||
def _test_rearrange_expert_weights_with_redundancy(
|
||||
env,
|
||||
world_size,
|
||||
num_layers,
|
||||
num_local_experts,
|
||||
num_logical_experts,
|
||||
eplb_communicator: str,
|
||||
env, world_size, num_layers, num_local_experts, num_logical_experts
|
||||
) -> None:
|
||||
# Initialize model parallel (using tensor parallel as an entrypoint
|
||||
# to expert parallel)
|
||||
@@ -425,8 +351,7 @@ def _test_rearrange_expert_weights_with_redundancy(
|
||||
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
|
||||
)
|
||||
|
||||
ep_group_coordinator = get_tp_group()
|
||||
ep_group = ep_group_coordinator.cpu_group
|
||||
ep_group = get_tp_group().cpu_group
|
||||
ep_rank = torch.distributed.get_rank()
|
||||
device = torch.device(f"cuda:{ep_rank}")
|
||||
|
||||
@@ -462,12 +387,6 @@ def _test_rearrange_expert_weights_with_redundancy(
|
||||
num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices
|
||||
)
|
||||
|
||||
communicator = create_eplb_communicator_or_raise(
|
||||
group_coordinator=ep_group_coordinator,
|
||||
backend=eplb_communicator,
|
||||
expert_weights=expert_weights[0],
|
||||
)
|
||||
|
||||
# Execute weight rearrangement
|
||||
rearrange_expert_weights_inplace(
|
||||
old_indices,
|
||||
@@ -475,33 +394,24 @@ def _test_rearrange_expert_weights_with_redundancy(
|
||||
expert_weights,
|
||||
ep_group,
|
||||
is_profile=False,
|
||||
communicator=communicator,
|
||||
)
|
||||
|
||||
# Verify the rearrangement result
|
||||
local_ok = verify_expert_weights_after_shuffle(
|
||||
expert_weights,
|
||||
new_indices,
|
||||
hidden_sizes,
|
||||
ep_rank,
|
||||
num_local_experts,
|
||||
)
|
||||
|
||||
local_ok = (
|
||||
verify_redundant_experts_have_same_weights(
|
||||
# Verify the rearrangement result
|
||||
verify_expert_weights_after_shuffle(
|
||||
expert_weights,
|
||||
new_indices,
|
||||
hidden_sizes,
|
||||
ep_rank,
|
||||
num_local_experts,
|
||||
)
|
||||
|
||||
verify_redundant_experts_have_same_weights(
|
||||
expert_weights,
|
||||
new_indices,
|
||||
hidden_sizes,
|
||||
world_size,
|
||||
num_local_experts,
|
||||
)
|
||||
and local_ok
|
||||
)
|
||||
assert_verification_synced(
|
||||
local_ok,
|
||||
"Rearrange verification failed on at least one rank. See logs for details.",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -527,13 +437,8 @@ def _test_rearrange_expert_weights_with_redundancy(
|
||||
(4, 8, 8, 16),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("eplb_communicator", ["torch_nccl", "torch_gloo", "pynccl"])
|
||||
def test_rearrange_expert_weights_with_redundancy(
|
||||
world_size,
|
||||
num_layers,
|
||||
num_local_experts,
|
||||
num_logical_experts,
|
||||
eplb_communicator,
|
||||
world_size, num_layers, num_local_experts, num_logical_experts
|
||||
):
|
||||
"""Test the functionality of rearranging expert weights with redundancy."""
|
||||
|
||||
@@ -545,7 +450,6 @@ def test_rearrange_expert_weights_with_redundancy(
|
||||
num_layers,
|
||||
num_local_experts,
|
||||
num_logical_experts,
|
||||
eplb_communicator,
|
||||
)
|
||||
|
||||
|
||||
@@ -560,8 +464,7 @@ def _test_rearrange_expert_weights_no_change(env, world_size) -> None:
|
||||
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
|
||||
)
|
||||
|
||||
ep_group_coordinator = get_tp_group()
|
||||
ep_group = ep_group_coordinator.cpu_group
|
||||
ep_group = get_tp_group().cpu_group
|
||||
ep_rank = torch.distributed.get_rank()
|
||||
device = torch.device(f"cuda:{ep_rank}")
|
||||
|
||||
@@ -591,40 +494,24 @@ def _test_rearrange_expert_weights_no_change(env, world_size) -> None:
|
||||
layer_copy.append(weight.clone())
|
||||
original_weights.append(layer_copy)
|
||||
|
||||
communicator = create_eplb_communicator_or_raise(
|
||||
group_coordinator=ep_group_coordinator,
|
||||
backend="torch_nccl",
|
||||
expert_weights=expert_weights[0],
|
||||
)
|
||||
|
||||
# Execute rearrangement (should be no change)
|
||||
rearrange_expert_weights_inplace(
|
||||
indices,
|
||||
indices, # Same indices
|
||||
expert_weights,
|
||||
ep_group,
|
||||
communicator,
|
||||
is_profile=False,
|
||||
)
|
||||
|
||||
# Verify that the weights have not changed
|
||||
local_ok = True
|
||||
for layer in range(num_layers):
|
||||
for weight_idx in range(len(hidden_sizes)):
|
||||
if not torch.equal(
|
||||
expert_weights[layer][weight_idx],
|
||||
original_weights[layer][weight_idx],
|
||||
):
|
||||
local_ok = False
|
||||
print(
|
||||
"test_rearrange_expert_weights_no_change failed: "
|
||||
f"layer={layer}, weight_idx={weight_idx}",
|
||||
flush=True,
|
||||
# Verify that the weights have not changed
|
||||
for layer in range(num_layers):
|
||||
for weight_idx in range(len(hidden_sizes)):
|
||||
torch.testing.assert_close(
|
||||
expert_weights[layer][weight_idx],
|
||||
original_weights[layer][weight_idx],
|
||||
msg=f"""Layer {layer}, weight {weight_idx}
|
||||
should remain unchanged""",
|
||||
)
|
||||
assert_verification_synced(
|
||||
local_ok,
|
||||
"No-change EPLB verification failed on at least one rank.",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -633,13 +520,11 @@ def _test_rearrange_expert_weights_no_change(env, world_size) -> None:
|
||||
(2, 2, 2, 3),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("eplb_communicator", ["torch_nccl", "torch_gloo", "pynccl"])
|
||||
def test_async_transfer_layer_without_mtp(
|
||||
world_size: int,
|
||||
num_layers: int,
|
||||
num_local_experts: int,
|
||||
num_logical_experts: int,
|
||||
eplb_communicator: str,
|
||||
):
|
||||
"""Exercise async EPLB transfer path without MTP/spec decode."""
|
||||
|
||||
@@ -652,7 +537,6 @@ def test_async_transfer_layer_without_mtp(
|
||||
num_layers,
|
||||
num_local_experts,
|
||||
num_logical_experts,
|
||||
eplb_communicator,
|
||||
)
|
||||
|
||||
|
||||
@@ -665,10 +549,7 @@ def test_rearrange_expert_weights_no_change(world_size):
|
||||
|
||||
if torch.accelerator.device_count() < world_size:
|
||||
pytest.skip(f"Need at least {world_size} GPUs to run the test")
|
||||
distributed_run(
|
||||
_test_rearrange_expert_weights_no_change,
|
||||
world_size,
|
||||
)
|
||||
distributed_run(_test_rearrange_expert_weights_no_change, world_size)
|
||||
|
||||
|
||||
def _test_rearrange_expert_weights_profile_mode(env, world_size) -> None:
|
||||
@@ -682,8 +563,7 @@ def _test_rearrange_expert_weights_profile_mode(env, world_size) -> None:
|
||||
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
|
||||
)
|
||||
|
||||
ep_group_coordinator = get_tp_group()
|
||||
ep_group = ep_group_coordinator.cpu_group
|
||||
ep_group = get_tp_group().cpu_group
|
||||
ep_rank = torch.distributed.get_rank()
|
||||
device = torch.device(f"cuda:{ep_rank}")
|
||||
|
||||
@@ -720,40 +600,23 @@ def _test_rearrange_expert_weights_profile_mode(env, world_size) -> None:
|
||||
layer_copy.append(weight.clone())
|
||||
original_weights.append(layer_copy)
|
||||
|
||||
communicator = create_eplb_communicator_or_raise(
|
||||
group_coordinator=ep_group_coordinator,
|
||||
backend="torch_nccl",
|
||||
expert_weights=expert_weights[0],
|
||||
)
|
||||
|
||||
# Execute profile mode rearrangement
|
||||
rearrange_expert_weights_inplace(
|
||||
old_indices,
|
||||
new_indices,
|
||||
expert_weights,
|
||||
ep_group,
|
||||
communicator,
|
||||
is_profile=True, # Profile mode
|
||||
)
|
||||
|
||||
# In profile mode, the weights should remain unchanged
|
||||
local_ok = True
|
||||
for layer in range(num_layers):
|
||||
for weight_idx in range(len(hidden_sizes)):
|
||||
if not torch.equal(
|
||||
expert_weights[layer][weight_idx],
|
||||
original_weights[layer][weight_idx],
|
||||
):
|
||||
local_ok = False
|
||||
print(
|
||||
"test_rearrange_expert_weights_profile_mode failed: "
|
||||
f"layer={layer}, weight_idx={weight_idx}",
|
||||
flush=True,
|
||||
# In profile mode, the weights should remain unchanged
|
||||
for layer in range(num_layers):
|
||||
for weight_idx in range(len(hidden_sizes)):
|
||||
torch.testing.assert_close(
|
||||
expert_weights[layer][weight_idx],
|
||||
original_weights[layer][weight_idx],
|
||||
msg="In profile mode, the weights should remain unchanged",
|
||||
)
|
||||
assert_verification_synced(
|
||||
local_ok,
|
||||
"Profile-mode EPLB verification failed on at least one rank.",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("world_size", [2, 4])
|
||||
@@ -762,7 +625,4 @@ def test_rearrange_expert_weights_profile_mode(world_size):
|
||||
|
||||
if torch.accelerator.device_count() < world_size:
|
||||
pytest.skip(f"Need at least {world_size} GPUs to run the test")
|
||||
distributed_run(
|
||||
_test_rearrange_expert_weights_profile_mode,
|
||||
world_size,
|
||||
)
|
||||
distributed_run(_test_rearrange_expert_weights_profile_mode, world_size)
|
||||
|
||||
@@ -13,6 +13,7 @@ from vllm.distributed.utils import StatelessProcessGroup
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.network_utils import get_open_port
|
||||
from vllm.utils.system_utils import update_environment_variables
|
||||
from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||
|
||||
from ..utils import multi_gpu_test
|
||||
|
||||
@@ -20,7 +21,7 @@ from ..utils import multi_gpu_test
|
||||
@ray.remote
|
||||
class _CUDADeviceCountStatelessTestActor:
|
||||
def get_count(self):
|
||||
return current_platform.device_count()
|
||||
return cuda_device_count_stateless()
|
||||
|
||||
def set_cuda_visible_devices(self, cuda_visible_devices: str):
|
||||
update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user