[Bugfix][DeepSeek-V3.2] fix fp8 kvcache type cast (#33884)

Signed-off-by: Kebe <mail@kebe7jun.com>
This commit is contained in:
Kebe
2026-02-11 12:31:36 +09:00
committed by GitHub
parent b5dcb372e4
commit 5ee5c86eeb

View File

@@ -1234,8 +1234,13 @@ void cp_gather_and_upconvert_fp8_kv_cache(
"src_cache and seq_lens must be on the same device");
TORCH_CHECK(src_cache.device() == workspace_starts.device(),
"src_cache and workspace_starts must be on the same device");
TORCH_CHECK(src_cache.dtype() == torch::kUInt8, "src_cache must be uint8");
auto dtype = src_cache.scalar_type();
TORCH_CHECK(
dtype == at::ScalarType::Byte || // uint8
dtype == at::ScalarType::Float8_e4m3fn || // fp8 e4m3
dtype == at::ScalarType::Float8_e5m2, // fp8 e5m2
"src_cache must be uint8, float8_e4m3fn, or float8_e5m2, but got ",
src_cache.dtype());
TORCH_CHECK(dst.dtype() == torch::kBFloat16, "dst must be bfloat16");
TORCH_CHECK(head_dim == 576, "head_dim must be 576 for MLA");
@@ -1244,14 +1249,21 @@ void cp_gather_and_upconvert_fp8_kv_cache(
int64_t cache_entry_stride = src_cache.stride(1);
int64_t dst_entry_stride = dst.stride(0);
const uint8_t* src_ptr = nullptr;
if (dtype == at::ScalarType::Byte) {
src_ptr = src_cache.data_ptr<uint8_t>();
} else {
// float8_e4m3fn or float8_e5m2
src_ptr = reinterpret_cast<const uint8_t*>(src_cache.data_ptr());
}
// Decide on the number of splits based on the batch size
int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16;
dim3 grid(batch_size, num_splits);
dim3 block(576);
vllm::cp_gather_and_upconvert_fp8_kv_cache<<<grid, block, 0, stream>>>(
src_cache.data_ptr<uint8_t>(),
reinterpret_cast<__nv_bfloat16*>(dst.data_ptr()),
src_ptr, reinterpret_cast<__nv_bfloat16*>(dst.data_ptr()),
block_table.data_ptr<int32_t>(), seq_lens.data_ptr<int32_t>(),
workspace_starts.data_ptr<int32_t>(), block_size, head_dim,
block_table_stride, cache_block_stride, cache_entry_stride,