[Bugfix][DeepSeek-V3.2] fix fp8 kvcache type cast (#33884)
Signed-off-by: Kebe <mail@kebe7jun.com>
This commit is contained in:
@@ -1234,8 +1234,13 @@ void cp_gather_and_upconvert_fp8_kv_cache(
|
||||
"src_cache and seq_lens must be on the same device");
|
||||
TORCH_CHECK(src_cache.device() == workspace_starts.device(),
|
||||
"src_cache and workspace_starts must be on the same device");
|
||||
|
||||
TORCH_CHECK(src_cache.dtype() == torch::kUInt8, "src_cache must be uint8");
|
||||
auto dtype = src_cache.scalar_type();
|
||||
TORCH_CHECK(
|
||||
dtype == at::ScalarType::Byte || // uint8
|
||||
dtype == at::ScalarType::Float8_e4m3fn || // fp8 e4m3
|
||||
dtype == at::ScalarType::Float8_e5m2, // fp8 e5m2
|
||||
"src_cache must be uint8, float8_e4m3fn, or float8_e5m2, but got ",
|
||||
src_cache.dtype());
|
||||
TORCH_CHECK(dst.dtype() == torch::kBFloat16, "dst must be bfloat16");
|
||||
TORCH_CHECK(head_dim == 576, "head_dim must be 576 for MLA");
|
||||
|
||||
@@ -1244,14 +1249,21 @@ void cp_gather_and_upconvert_fp8_kv_cache(
|
||||
int64_t cache_entry_stride = src_cache.stride(1);
|
||||
int64_t dst_entry_stride = dst.stride(0);
|
||||
|
||||
const uint8_t* src_ptr = nullptr;
|
||||
if (dtype == at::ScalarType::Byte) {
|
||||
src_ptr = src_cache.data_ptr<uint8_t>();
|
||||
} else {
|
||||
// float8_e4m3fn or float8_e5m2
|
||||
src_ptr = reinterpret_cast<const uint8_t*>(src_cache.data_ptr());
|
||||
}
|
||||
|
||||
// Decide on the number of splits based on the batch size
|
||||
int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16;
|
||||
dim3 grid(batch_size, num_splits);
|
||||
dim3 block(576);
|
||||
|
||||
vllm::cp_gather_and_upconvert_fp8_kv_cache<<<grid, block, 0, stream>>>(
|
||||
src_cache.data_ptr<uint8_t>(),
|
||||
reinterpret_cast<__nv_bfloat16*>(dst.data_ptr()),
|
||||
src_ptr, reinterpret_cast<__nv_bfloat16*>(dst.data_ptr()),
|
||||
block_table.data_ptr<int32_t>(), seq_lens.data_ptr<int32_t>(),
|
||||
workspace_starts.data_ptr<int32_t>(), block_size, head_dim,
|
||||
block_table_stride, cache_block_stride, cache_entry_stride,
|
||||
|
||||
Reference in New Issue
Block a user