[Attention][Perf] Optimize cp_gather_and_upconvert_fp8_kv_cache - DeepSeek-v3.2 (#35290)
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com> Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
committed by
GitHub
parent
70485a11bd
commit
2b28b9b269
@@ -995,75 +995,67 @@ namespace vllm {
|
||||
// Similar to cp_gather_cache but specifically for FP8->BF16 conversion
|
||||
__global__ void cp_gather_and_upconvert_fp8_kv_cache(
|
||||
const uint8_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656]
|
||||
__nv_bfloat16* __restrict__ dst, // [TOT_TOKENS, 576]
|
||||
const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES]
|
||||
const int32_t* __restrict__ seq_lens, // [BATCH]
|
||||
const int32_t* __restrict__ workspace_starts, // [BATCH]
|
||||
const int32_t block_size, const int32_t head_dim,
|
||||
const int64_t block_table_stride, const int64_t cache_block_stride,
|
||||
const int64_t cache_entry_stride, const int64_t dst_entry_stride) {
|
||||
const int64_t bid = blockIdx.x; // Batch ID
|
||||
const int32_t num_splits = gridDim.y;
|
||||
const int32_t split = blockIdx.y;
|
||||
const int32_t seq_start = workspace_starts[bid];
|
||||
const int32_t seq_len = seq_lens[bid];
|
||||
const int32_t tot_slots = seq_len;
|
||||
const int32_t split_slots = cuda_utils::ceil_div(tot_slots, num_splits);
|
||||
__nv_bfloat16* __restrict__ dst, // [total_tokens, 576]
|
||||
const int32_t* __restrict__ block_table, // [num_reqs, BLOCK_INDICES]
|
||||
const int32_t* __restrict__ workspace_starts, // [num_reqs]
|
||||
const int32_t num_reqs, const int32_t block_size,
|
||||
const int32_t total_tokens, const int64_t block_table_stride,
|
||||
const int64_t cache_block_stride, const int64_t cache_entry_stride,
|
||||
const int64_t dst_entry_stride) {
|
||||
const int flat_warp_id = (blockIdx.x * blockDim.x + threadIdx.x) >> 5;
|
||||
if (flat_warp_id >= total_tokens) return;
|
||||
const int lane_id = threadIdx.x & 31;
|
||||
|
||||
const int32_t split_start = split * split_slots;
|
||||
const int32_t split_end = min((split + 1) * split_slots, tot_slots);
|
||||
|
||||
const bool is_active_split = (split_start < tot_slots);
|
||||
|
||||
if (!is_active_split) return;
|
||||
|
||||
// Adjust the pointer for the block_table for this batch
|
||||
const int32_t batch_offset = bid * block_table_stride;
|
||||
int32_t offset = split_start;
|
||||
int32_t offset_div = offset / block_size;
|
||||
offset = offset % block_size;
|
||||
const int32_t* batch_block_table = block_table + batch_offset;
|
||||
|
||||
// Adjust dst pointer based on the cumulative sequence lengths
|
||||
dst += seq_start * dst_entry_stride;
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
// Process each token in this split
|
||||
for (int pid = split_start; pid < split_end; ++pid) {
|
||||
auto block_id = batch_block_table[offset_div];
|
||||
const uint8_t* token_ptr =
|
||||
src_cache + block_id * cache_block_stride + offset * cache_entry_stride;
|
||||
__nv_bfloat16* dst_ptr = dst + pid * dst_entry_stride;
|
||||
|
||||
// FP8 format: 512 bytes fp8 + 16 bytes scales + 128 bytes rope (64 bf16)
|
||||
const uint8_t* no_pe_ptr = token_ptr;
|
||||
const float* scales_ptr = reinterpret_cast<const float*>(token_ptr + 512);
|
||||
const __nv_bfloat16* rope_ptr =
|
||||
reinterpret_cast<const __nv_bfloat16*>(token_ptr + 512 + 16);
|
||||
|
||||
// Parallelize fp8 dequant (512 elements) and rope copy (64 elements)
|
||||
if (tid < 512) {
|
||||
// FP8 dequantization
|
||||
const int tile = tid >> 7; // each tile is 128 elements
|
||||
const float scale = scales_ptr[tile];
|
||||
const uint8_t val = no_pe_ptr[tid];
|
||||
dst_ptr[tid] =
|
||||
fp8::scaled_convert<__nv_bfloat16, uint8_t,
|
||||
vllm::Fp8KVCacheDataType::kFp8E4M3>(val, scale);
|
||||
} else if (tid < 576) {
|
||||
// Rope copy (64 bf16 elements)
|
||||
const int rope_idx = tid - 512;
|
||||
dst_ptr[512 + rope_idx] = rope_ptr[rope_idx];
|
||||
}
|
||||
|
||||
// Move to next token
|
||||
offset += 1;
|
||||
if (offset == block_size) {
|
||||
offset_div += 1;
|
||||
offset = 0;
|
||||
}
|
||||
// Binary search to find which request owns this output token
|
||||
int lo = 0, hi = num_reqs - 1;
|
||||
while (lo < hi) {
|
||||
int mid = (lo + hi + 1) >> 1;
|
||||
if (workspace_starts[mid] <= flat_warp_id)
|
||||
lo = mid;
|
||||
else
|
||||
hi = mid - 1;
|
||||
}
|
||||
const int req_id = lo;
|
||||
|
||||
// Compute physical token address via block table
|
||||
const int out_token_id = flat_warp_id;
|
||||
const int token_offset = out_token_id - workspace_starts[req_id];
|
||||
const int cache_block_idx = token_offset / block_size;
|
||||
const int offset_in_block = token_offset % block_size;
|
||||
const int physical_block =
|
||||
block_table[req_id * block_table_stride + cache_block_idx];
|
||||
|
||||
const uint8_t* token_ptr = src_cache + physical_block * cache_block_stride +
|
||||
offset_in_block * cache_entry_stride;
|
||||
|
||||
const int4* nope_src = reinterpret_cast<const int4*>(token_ptr);
|
||||
const int4 fp8_data = nope_src[lane_id];
|
||||
|
||||
const float* scales_ptr = reinterpret_cast<const float*>(token_ptr + 512);
|
||||
const float scale = scales_ptr[lane_id >> 3];
|
||||
|
||||
const uint2 fp8_lo = make_uint2(fp8_data.x, fp8_data.y);
|
||||
const uint2 fp8_hi = make_uint2(fp8_data.z, fp8_data.w);
|
||||
#ifdef USE_ROCM
|
||||
const bf16_8_t bf16_lo =
|
||||
fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_lo, scale);
|
||||
const bf16_8_t bf16_hi =
|
||||
fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_hi, scale);
|
||||
#else
|
||||
const bf16_8_t bf16_lo =
|
||||
fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_lo, scale, __NV_E4M3);
|
||||
const bf16_8_t bf16_hi =
|
||||
fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_hi, scale, __NV_E4M3);
|
||||
#endif
|
||||
|
||||
__nv_bfloat16* dst_ptr = dst + out_token_id * dst_entry_stride;
|
||||
int4* nope_dst = reinterpret_cast<int4*>(dst_ptr) + lane_id * 2;
|
||||
nope_dst[0] = *reinterpret_cast<const int4*>(&bf16_lo);
|
||||
nope_dst[1] = *reinterpret_cast<const int4*>(&bf16_hi);
|
||||
|
||||
const int* rope_src = reinterpret_cast<const int*>(token_ptr + 528);
|
||||
int* rope_dst = reinterpret_cast<int*>(dst_ptr + 512);
|
||||
rope_dst[lane_id] = rope_src[lane_id];
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
@@ -1257,15 +1249,16 @@ void cp_gather_and_upconvert_fp8_kv_cache(
|
||||
src_ptr = reinterpret_cast<const uint8_t*>(src_cache.data_ptr());
|
||||
}
|
||||
|
||||
// Decide on the number of splits based on the batch size
|
||||
int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16;
|
||||
dim3 grid(batch_size, num_splits);
|
||||
dim3 block(576);
|
||||
const int total_tokens = dst.size(0);
|
||||
constexpr int warps_per_block = 8;
|
||||
const int grid_size = (total_tokens + warps_per_block - 1) / warps_per_block;
|
||||
const int block_size_threads = warps_per_block * 32; // 256 threads
|
||||
|
||||
vllm::cp_gather_and_upconvert_fp8_kv_cache<<<grid, block, 0, stream>>>(
|
||||
vllm::cp_gather_and_upconvert_fp8_kv_cache<<<grid_size, block_size_threads, 0,
|
||||
stream>>>(
|
||||
src_ptr, reinterpret_cast<__nv_bfloat16*>(dst.data_ptr()),
|
||||
block_table.data_ptr<int32_t>(), seq_lens.data_ptr<int32_t>(),
|
||||
workspace_starts.data_ptr<int32_t>(), block_size, head_dim,
|
||||
block_table.data_ptr<int32_t>(), workspace_starts.data_ptr<int32_t>(),
|
||||
static_cast<int32_t>(batch_size), block_size, total_tokens,
|
||||
block_table_stride, cache_block_stride, cache_entry_stride,
|
||||
dst_entry_stride);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user