[Bugfix][DeepSeek-V3.2] fix fp8 kvcache type cast (#33884)
Signed-off-by: Kebe <mail@kebe7jun.com>
This commit is contained in:
@@ -1234,8 +1234,13 @@ void cp_gather_and_upconvert_fp8_kv_cache(
|
|||||||
"src_cache and seq_lens must be on the same device");
|
"src_cache and seq_lens must be on the same device");
|
||||||
TORCH_CHECK(src_cache.device() == workspace_starts.device(),
|
TORCH_CHECK(src_cache.device() == workspace_starts.device(),
|
||||||
"src_cache and workspace_starts must be on the same device");
|
"src_cache and workspace_starts must be on the same device");
|
||||||
|
auto dtype = src_cache.scalar_type();
|
||||||
TORCH_CHECK(src_cache.dtype() == torch::kUInt8, "src_cache must be uint8");
|
TORCH_CHECK(
|
||||||
|
dtype == at::ScalarType::Byte || // uint8
|
||||||
|
dtype == at::ScalarType::Float8_e4m3fn || // fp8 e4m3
|
||||||
|
dtype == at::ScalarType::Float8_e5m2, // fp8 e5m2
|
||||||
|
"src_cache must be uint8, float8_e4m3fn, or float8_e5m2, but got ",
|
||||||
|
src_cache.dtype());
|
||||||
TORCH_CHECK(dst.dtype() == torch::kBFloat16, "dst must be bfloat16");
|
TORCH_CHECK(dst.dtype() == torch::kBFloat16, "dst must be bfloat16");
|
||||||
TORCH_CHECK(head_dim == 576, "head_dim must be 576 for MLA");
|
TORCH_CHECK(head_dim == 576, "head_dim must be 576 for MLA");
|
||||||
|
|
||||||
@@ -1244,14 +1249,21 @@ void cp_gather_and_upconvert_fp8_kv_cache(
|
|||||||
int64_t cache_entry_stride = src_cache.stride(1);
|
int64_t cache_entry_stride = src_cache.stride(1);
|
||||||
int64_t dst_entry_stride = dst.stride(0);
|
int64_t dst_entry_stride = dst.stride(0);
|
||||||
|
|
||||||
|
const uint8_t* src_ptr = nullptr;
|
||||||
|
if (dtype == at::ScalarType::Byte) {
|
||||||
|
src_ptr = src_cache.data_ptr<uint8_t>();
|
||||||
|
} else {
|
||||||
|
// float8_e4m3fn or float8_e5m2
|
||||||
|
src_ptr = reinterpret_cast<const uint8_t*>(src_cache.data_ptr());
|
||||||
|
}
|
||||||
|
|
||||||
// Decide on the number of splits based on the batch size
|
// Decide on the number of splits based on the batch size
|
||||||
int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16;
|
int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16;
|
||||||
dim3 grid(batch_size, num_splits);
|
dim3 grid(batch_size, num_splits);
|
||||||
dim3 block(576);
|
dim3 block(576);
|
||||||
|
|
||||||
vllm::cp_gather_and_upconvert_fp8_kv_cache<<<grid, block, 0, stream>>>(
|
vllm::cp_gather_and_upconvert_fp8_kv_cache<<<grid, block, 0, stream>>>(
|
||||||
src_cache.data_ptr<uint8_t>(),
|
src_ptr, reinterpret_cast<__nv_bfloat16*>(dst.data_ptr()),
|
||||||
reinterpret_cast<__nv_bfloat16*>(dst.data_ptr()),
|
|
||||||
block_table.data_ptr<int32_t>(), seq_lens.data_ptr<int32_t>(),
|
block_table.data_ptr<int32_t>(), seq_lens.data_ptr<int32_t>(),
|
||||||
workspace_starts.data_ptr<int32_t>(), block_size, head_dim,
|
workspace_starts.data_ptr<int32_t>(), block_size, head_dim,
|
||||||
block_table_stride, cache_block_stride, cache_entry_stride,
|
block_table_stride, cache_block_stride, cache_entry_stride,
|
||||||
|
|||||||
Reference in New Issue
Block a user