diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 969c28c75..10d540a1d 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -1234,8 +1234,13 @@ void cp_gather_and_upconvert_fp8_kv_cache( "src_cache and seq_lens must be on the same device"); TORCH_CHECK(src_cache.device() == workspace_starts.device(), "src_cache and workspace_starts must be on the same device"); - - TORCH_CHECK(src_cache.dtype() == torch::kUInt8, "src_cache must be uint8"); + auto dtype = src_cache.scalar_type(); + TORCH_CHECK( + dtype == at::ScalarType::Byte || // uint8 + dtype == at::ScalarType::Float8_e4m3fn || // fp8 e4m3 + dtype == at::ScalarType::Float8_e5m2, // fp8 e5m2 + "src_cache must be uint8, float8_e4m3fn, or float8_e5m2, but got ", + src_cache.dtype()); TORCH_CHECK(dst.dtype() == torch::kBFloat16, "dst must be bfloat16"); TORCH_CHECK(head_dim == 576, "head_dim must be 576 for MLA"); @@ -1244,14 +1249,21 @@ void cp_gather_and_upconvert_fp8_kv_cache( int64_t cache_entry_stride = src_cache.stride(1); int64_t dst_entry_stride = dst.stride(0); + const uint8_t* src_ptr = nullptr; + if (dtype == at::ScalarType::Byte) { + src_ptr = src_cache.data_ptr(); + } else { + // float8_e4m3fn or float8_e5m2 + src_ptr = reinterpret_cast(src_cache.data_ptr()); + } + // Decide on the number of splits based on the batch size int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16; dim3 grid(batch_size, num_splits); dim3 block(576); vllm::cp_gather_and_upconvert_fp8_kv_cache<<>>( - src_cache.data_ptr(), - reinterpret_cast<__nv_bfloat16*>(dst.data_ptr()), + src_ptr, reinterpret_cast<__nv_bfloat16*>(dst.data_ptr()), block_table.data_ptr(), seq_lens.data_ptr(), workspace_starts.data_ptr(), block_size, head_dim, block_table_stride, cache_block_stride, cache_entry_stride,