[ROCm] Fix KV copy methods and auto-select attention backend for ROCm (#36845)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
Andreas Karatzas
2026-03-16 03:07:27 -05:00
committed by GitHub
parent 8d3f8f485e
commit 911355e216
2 changed files with 75 additions and 17 deletions
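The change touches two files: the NixlConnector PD + spec decode acceptance test script (platform detection, CUDA/HIP device-visibility handling, and a platform-dependent default attention backend) and the `RocmPlatform` class (KV-cache block copy methods between device and host).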


@@ -21,6 +21,11 @@
 # MODEL_NAME - target model (default: meta-llama/Llama-3.1-8B-Instruct)
 # NUM_SPEC_TOKENS - number of speculative tokens (default: 3)
 # GPU_MEMORY_UTILIZATION - (default: 0.7)
+# ATTENTION_BACKEND - attention backend to use
+#   Default: TRITON_ATTN on ROCm, FLASH_ATTN on NVIDIA
+#   ROCm options: TRITON_ATTN, ROCM_ATTN, ROCM_AITER_FA,
+#                 ROCM_AITER_UNIFIED_ATTN
+#   NVIDIA options: FLASH_ATTN, FLASHINFER
 set -x
 # ── Model & spec decode config ──────────────────────────────────────────
@@ -51,6 +56,28 @@ GIT_ROOT=$(git rev-parse --show-toplevel)
 SMI_BIN=$(which nvidia-smi || which rocm-smi || echo "")
+# ── Detect platform (NVIDIA vs ROCm) ────────────────────────────────────
+if [[ "$SMI_BIN" == *"rocm"* ]]; then
+  GPU_PLATFORM="rocm"
+  GPU_DEVICE_VAR="HIP_VISIBLE_DEVICES"
+else
+  GPU_PLATFORM="nvidia"
+  GPU_DEVICE_VAR="CUDA_VISIBLE_DEVICES"
+fi
+echo "Detected GPU platform: ${GPU_PLATFORM} (using ${GPU_DEVICE_VAR})"
+
+# ── Attention backend config ─────────────────────────────────────────────
+if [[ -z "${ATTENTION_BACKEND:-}" ]]; then
+  if [[ "$GPU_PLATFORM" == "rocm" ]]; then
+    ATTENTION_BACKEND="TRITON_ATTN"
+  else
+    ATTENTION_BACKEND="FLASH_ATTN"
+  fi
+fi
+echo "Using attention backend: ${ATTENTION_BACKEND}"
+
 cleanup_instances() {
   echo ""
   echo "Cleaning up..."
@@ -84,13 +111,16 @@ wait_for_server() {
 # ── Resolve GPU list ─────────────────────────────────────────────────────
-if [[ -n "${CUDA_VISIBLE_DEVICES:-}" ]]; then
-  IFS=',' read -ra ALL_GPUS <<< "$CUDA_VISIBLE_DEVICES"
+# Accept either CUDA_VISIBLE_DEVICES or HIP_VISIBLE_DEVICES
+VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-${HIP_VISIBLE_DEVICES:-}}"
+if [[ -n "${VISIBLE_DEVICES}" ]]; then
+  IFS=',' read -ra ALL_GPUS <<< "$VISIBLE_DEVICES"
 else
   ALL_GPUS=()
-  if [[ "$SMI_BIN" == *"nvidia"* ]]; then
+  if [[ "$GPU_PLATFORM" == "nvidia" ]]; then
     num=$($SMI_BIN --query-gpu=name --format=csv,noheader | wc -l)
-  elif [[ "$SMI_BIN" == *"rocm"* ]]; then
+  elif [[ "$GPU_PLATFORM" == "rocm" ]]; then
     num=$($SMI_BIN -l | grep -c GPU)
   else
     num=1
@@ -100,7 +130,7 @@ fi
 TOTAL_GPUS_NEEDED=$(( (NUM_PREFILL_INSTANCES * PREFILLER_TP_SIZE) + (NUM_DECODE_INSTANCES * DECODER_TP_SIZE) ))
 if [[ ${#ALL_GPUS[@]} -lt $TOTAL_GPUS_NEEDED ]]; then
-  echo "FAIL: Need $TOTAL_GPUS_NEEDED GPUs but only have ${#ALL_GPUS[@]} (CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-not set})"
+  echo "FAIL: Need $TOTAL_GPUS_NEEDED GPUs but only have ${#ALL_GPUS[@]} (visible devices=${VISIBLE_DEVICES:-not set})"
   exit 1
 fi
@@ -119,12 +149,14 @@ run_test_for_device() {
   echo "================================================================"
   echo "NixlConnector PD + Spec Decode Acceptance Test (kv_buffer_device=${kv_device})"
   echo "================================================================"
-  echo "Model: ${MODEL_NAME}"
-  echo "SD method: ${SD_METHOD}"
-  echo "SD model: ${SD_MODEL}"
-  echo "Spec tokens: ${NUM_SPEC_TOKENS}"
-  echo "KV buffer device: ${kv_device}"
-  echo "GPUs available: ${ALL_GPUS[*]}"
+  echo "Model: ${MODEL_NAME}"
+  echo "SD method: ${SD_METHOD}"
+  echo "SD model: ${SD_MODEL}"
+  echo "Spec tokens: ${NUM_SPEC_TOKENS}"
+  echo "KV buffer device: ${kv_device}"
+  echo "Attention backend: ${ATTENTION_BACKEND}"
+  echo "GPU platform: ${GPU_PLATFORM}"
+  echo "GPUs available: ${ALL_GPUS[*]}"
   echo "================================================================"

   local PREFILL_HOSTS=()
@@ -146,7 +178,8 @@ run_test_for_device() {
     local SIDE_CHANNEL_PORT=$((5559 + i))
     echo "Starting prefill instance $i on GPU $GPU_ID, port $PORT"
-    CUDA_VISIBLE_DEVICES=$GPU_ID \
+    env \
+      ${GPU_DEVICE_VAR}=$GPU_ID \
     VLLM_KV_CACHE_LAYOUT='HND' \
     UCX_NET_DEVICES=all \
     VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT \
@@ -159,7 +192,7 @@ run_test_for_device() {
       --tensor-parallel-size $PREFILLER_TP_SIZE \
       --kv-transfer-config "$kv_config" \
       --speculative-config "$PREFILL_SPEC_CONFIG" \
-      --attention-backend FLASH_ATTN &
+      --attention-backend $ATTENTION_BACKEND &

     PREFILL_HOSTS+=("localhost")
     PREFILL_PORTS+=("$PORT")
@@ -178,7 +211,8 @@ run_test_for_device() {
     local SIDE_CHANNEL_PORT=$((5659 + i * $DECODER_TP_SIZE))
     echo "Starting decode instance $i on GPU $GPU_ID, port $PORT"
-    CUDA_VISIBLE_DEVICES=$GPU_ID \
+    env \
+      ${GPU_DEVICE_VAR}=$GPU_ID \
     VLLM_KV_CACHE_LAYOUT='HND' \
     UCX_NET_DEVICES=all \
     VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT \
@@ -191,7 +225,7 @@ run_test_for_device() {
      --tensor-parallel-size $DECODER_TP_SIZE \
      --kv-transfer-config "$kv_config" \
      --speculative-config "$DECODE_SPEC_CONFIG" \
-     --attention-backend FLASH_ATTN &
+     --attention-backend $ATTENTION_BACKEND &

    DECODE_HOSTS+=("localhost")
    DECODE_PORTS+=("$PORT")
@@ -218,7 +252,7 @@ run_test_for_device() {
   sleep 5

   # Run test
-  echo "Running spec decode acceptance test (kv_buffer_device=${kv_device})..."
+  echo "Running spec decode acceptance test (kv_buffer_device=${kv_device}, backend=${ATTENTION_BACKEND})..."
   DECODE_PORT=${DECODE_PORTS[0]} \
   TEST_MODEL=$MODEL_NAME \
   python3 -m pytest -s -x "${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/test_spec_decode_acceptance.py"
@@ -234,4 +268,4 @@ for device in $KV_BUFFER_DEVICES; do
   run_test_for_device "$device"
 done

-echo "=== All spec decode acceptance tests passed ==="
+echo "=== All spec decode acceptance tests passed (backend=${ATTENTION_BACKEND}) ==="


@@ -851,6 +851,30 @@ class RocmPlatform(Platform):
                 "`dtype` flag in CLI, for example: --dtype=half."
             )

+    @classmethod
+    def insert_blocks_to_device(
+        cls,
+        src_cache: torch.Tensor,
+        dst_cache: torch.Tensor,
+        src_block_indices: torch.Tensor,
+        dst_block_indices: torch.Tensor,
+    ) -> None:
+        """Copy blocks from src_cache to dst_cache on GPU."""
+        _src_cache = src_cache[:, src_block_indices]
+        dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device)
+
+    @classmethod
+    def swap_out_blocks_to_host(
+        cls,
+        src_cache: torch.Tensor,
+        dst_cache: torch.Tensor,
+        src_block_indices: torch.Tensor,
+        dst_block_indices: torch.Tensor,
+    ) -> None:
+        """Copy blocks from GPU to host (CPU)."""
+        _src_cache = src_cache[:, src_block_indices]
+        dst_cache[:, dst_block_indices] = _src_cache.cpu()
+
     @classmethod
     def support_hybrid_kv_cache(cls) -> bool:
         return True
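A minimal sketch of the copy semantics these two classmethods implement, runnable with CPU stand-ins (the cache layout used here, `[2 (K/V), num_blocks, block_size, num_kv_heads, head_dim]`, is an assumption for illustration; the methods only require that dimension 1 indexes blocks):

    import torch

    # Assumed layout: [2 (K and V), num_blocks, block_size, num_kv_heads, head_dim]
    gpu_cache = torch.randn(2, 8, 16, 4, 64)   # stand-in for the device-side cache
    cpu_cache = torch.zeros(2, 8, 16, 4, 64)   # stand-in for the host swap space
    src = torch.tensor([0, 3, 5])              # blocks to copy out of the source
    dst = torch.tensor([2, 4, 6])              # slots to fill in the destination

    # swap_out_blocks_to_host: selected blocks go device -> host,
    # with K and V moved together via the leading `:` index
    cpu_cache[:, dst] = gpu_cache[:, src].cpu()

    # insert_blocks_to_device: the reverse direction, host -> device
    gpu_cache[:, src] = cpu_cache[:, dst].to(gpu_cache.device)

    assert torch.equal(cpu_cache[:, dst], gpu_cache[:, src])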