[New Model] DeepSeek-V3.2 (Rebased to Main) (#25896)
Signed-off-by: Chen Zhang <zhangch99@outlook.com> Signed-off-by: youkaichao <youkaichao@gmail.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: mgoin <mgoin64@gmail.com> Signed-off-by: NickLucche <nlucches@redhat.com> Signed-off-by: Yongye Zhu <zyy1102000@gmail.com> Signed-off-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com> Signed-off-by: Lucia Fang <fanglu@meta.com> Co-authored-by: Chen Zhang <zhangch99@outlook.com> Co-authored-by: youkaichao <youkaichao@gmail.com> Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com> Co-authored-by: yewentao256 <zhyanwentao@126.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Co-authored-by: mgoin <mgoin64@gmail.com> Co-authored-by: Lucia Fang <116399278+luccafong@users.noreply.github.com> Co-authored-by: Lucia Fang <fanglu@meta.com> Co-authored-by: NickLucche <nlucches@redhat.com> Co-authored-by: Siyuan Fu <siyuanf@nvidia.com> Co-authored-by: Matthew Bonanni <mbonanni@redhat.com> Co-authored-by: Xiaozhu Meng <mxz297@gmail.com> Co-authored-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com> Signed-off-by: simon-mo <simon.mo@hey.com>
This commit is contained in:
@@ -16,6 +16,7 @@
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cfloat> // FLT_MIN
|
||||
#include <map>
|
||||
#include <vector>
|
||||
|
||||
@@ -396,6 +397,176 @@ __global__ void concat_and_cache_mla_kernel(
|
||||
copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
|
||||
__global__ void concat_and_cache_ds_mla_kernel(
|
||||
const scalar_t* __restrict__ kv_c, // [num_tokens, kv_lora_rank]
|
||||
const scalar_t* __restrict__ k_pe, // [num_tokens, pe_dim]
|
||||
cache_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank
|
||||
// + pe_dim)]
|
||||
const int64_t* __restrict__ slot_mapping, // [num_tokens]
|
||||
const int block_stride, //
|
||||
const int entry_stride, //
|
||||
const int kv_c_stride, //
|
||||
const int k_pe_stride, //
|
||||
const int kv_lora_rank, //
|
||||
const int pe_dim, //
|
||||
const int block_size, //
|
||||
const float* scale //
|
||||
) {
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
const int64_t slot_idx = slot_mapping[token_idx];
|
||||
// NOTE: slot_idx can be -1 if the token is padded
|
||||
if (slot_idx < 0) {
|
||||
return;
|
||||
}
|
||||
const int64_t block_idx = slot_idx / block_size;
|
||||
const int64_t block_offset = slot_idx % block_size;
|
||||
const int64_t dst_idx_start =
|
||||
block_idx * block_stride + block_offset * entry_stride;
|
||||
|
||||
// Create 4 tile scales in shared memory
|
||||
__shared__ float smem[20];
|
||||
float* shard_abs_max = smem;
|
||||
float* tile_scales = smem + 16;
|
||||
|
||||
// For the NoPE part, each tile of 128 elements is handled by 4 warps
|
||||
// (128 threads). There are 4 total tiles, so 16 warps (512 threads).
|
||||
// The first thread of the first warp in each tile writes the scale
|
||||
// value for the tile. The RoPE part (last 64 elements) is handled
|
||||
// by another 2 warps (64 threads).
|
||||
// So in total, we use 18 warps (576 threads) per block.
|
||||
|
||||
// Cast kv_cache to 16_bit for RoPE values
|
||||
scalar_t* kv_cache_16bit =
|
||||
reinterpret_cast<scalar_t*>(&kv_cache[dst_idx_start]);
|
||||
|
||||
// The last 64 threads handle the RoPE part
|
||||
if (threadIdx.x >= kv_lora_rank) {
|
||||
const int8_t pe_idx = threadIdx.x - kv_lora_rank;
|
||||
const int64_t src_idx = token_idx * k_pe_stride + pe_idx;
|
||||
// RoPE values start after the packed 8-bit NoPE values and the
|
||||
// 32-bit scales
|
||||
const int64_t dst_idx = kv_lora_rank / 2 + 8 + pe_idx;
|
||||
kv_cache_16bit[dst_idx] = k_pe[src_idx];
|
||||
return;
|
||||
}
|
||||
|
||||
// Determine the scale for each chunk of NoPE
|
||||
const int16_t tile_idx = threadIdx.x >> 7;
|
||||
const int16_t warp_idx = (threadIdx.x & 127) >> 5;
|
||||
const int16_t lane_idx = threadIdx.x & 31;
|
||||
|
||||
// Load the NoPE element for this thread into registers
|
||||
const int64_t src_idx = token_idx * kv_c_stride + threadIdx.x;
|
||||
const scalar_t src_val = kv_c[src_idx];
|
||||
|
||||
// Warp-level reduction to find the max absolute value in the warp
|
||||
float max_abs = fabsf(src_val);
|
||||
#pragma unroll
|
||||
for (int offset = 16; offset > 0; offset /= 2) {
|
||||
#ifdef USE_ROCM
|
||||
max_abs = fmaxf(max_abs, __shfl_down_sync(UINT64_MAX, max_abs, offset));
|
||||
#else
|
||||
max_abs = fmaxf(max_abs, __shfl_down_sync(0xFFFFFFFF, max_abs, offset));
|
||||
#endif
|
||||
}
|
||||
|
||||
// The first lane of each warp in each tile writes the max_abs of this part
|
||||
// of the tile to shared memory
|
||||
if (lane_idx == 0) {
|
||||
shard_abs_max[tile_idx * 4 + warp_idx] = max_abs;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// The first lane of the first warp in each tile computes the scale for the
|
||||
// tile and writes it to shared memory and to kv_cache
|
||||
if (warp_idx == 0 && lane_idx == 0) {
|
||||
float4 shard_abs_max_vec =
|
||||
reinterpret_cast<float4*>(shard_abs_max)[tile_idx];
|
||||
float tile_scale = fmaxf(fmaxf(shard_abs_max_vec.x, shard_abs_max_vec.y),
|
||||
fmaxf(shard_abs_max_vec.z, shard_abs_max_vec.w)) /
|
||||
448.f;
|
||||
|
||||
// Avoid division by zero in `scaled_convert`
|
||||
tile_scales[tile_idx] = fmaxf(tile_scale, FLT_MIN);
|
||||
float* kv_cache_32bit = reinterpret_cast<float*>(&kv_cache[dst_idx_start]);
|
||||
const uint64_t dst_idx = kv_lora_rank / 4 + tile_idx;
|
||||
kv_cache_32bit[dst_idx] = tile_scales[tile_idx];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Now all threads in the block scale and write their element
|
||||
const float scale_val = tile_scales[tile_idx];
|
||||
const int64_t dst_idx = dst_idx_start + threadIdx.x;
|
||||
kv_cache[dst_idx] =
|
||||
fp8::scaled_convert<uint8_t, scalar_t, Fp8KVCacheDataType::kFp8E4M3>(
|
||||
src_val, scale_val);
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
|
||||
__global__ void indexer_k_quant_and_cache_kernel(
|
||||
const scalar_t* __restrict__ k, // [num_tokens, head_dim]
|
||||
cache_t* __restrict__ kv_cache, // [num_blocks, block_size, cache_stride]
|
||||
const int64_t* __restrict__ slot_mapping, // [num_tokens]
|
||||
const int head_dim, // dimension of each head
|
||||
const int quant_block_size, // quantization block size
|
||||
const int cache_block_size, // cache block size
|
||||
const int cache_stride, // stride for each token in kv_cache
|
||||
const bool use_ue8m0 // use ue8m0 scale format
|
||||
) {
|
||||
constexpr int VEC_SIZE = 4;
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
const int64_t head_dim_idx = (blockIdx.y * blockDim.y * blockDim.x +
|
||||
threadIdx.y * blockDim.x + threadIdx.x) *
|
||||
VEC_SIZE;
|
||||
const int64_t slot_idx = slot_mapping[token_idx];
|
||||
const int64_t block_idx = slot_idx / cache_block_size;
|
||||
const int64_t block_offset = slot_idx % cache_block_size;
|
||||
|
||||
// NOTE: slot_idx can be -1 if the token is padded
|
||||
if (slot_idx < 0 || (head_dim_idx >= head_dim)) {
|
||||
return;
|
||||
}
|
||||
|
||||
float2 k_val = (reinterpret_cast<const float2*>(
|
||||
k))[(token_idx * head_dim + head_dim_idx) / VEC_SIZE];
|
||||
scalar_t* k_val_ptr = reinterpret_cast<scalar_t*>(&k_val);
|
||||
float amax = 0.0f;
|
||||
for (int i = 0; i < VEC_SIZE; i++) {
|
||||
amax = fmaxf(amax, fabsf(float(k_val_ptr[i])));
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
// Reduced amax
|
||||
for (int mask = 16; mask > 0; mask /= 2) {
|
||||
#ifdef USE_ROCM
|
||||
amax = fmaxf(amax, __shfl_xor_sync(uint64_t(-1), amax, mask));
|
||||
#else
|
||||
amax = fmaxf(amax, __shfl_xor_sync(unsigned(-1), amax, mask));
|
||||
#endif
|
||||
}
|
||||
__syncwarp();
|
||||
float scale = fmaxf(amax, 1e-4) / 448.0f;
|
||||
if (use_ue8m0) {
|
||||
scale = exp2f(ceilf(log2f(scale)));
|
||||
}
|
||||
|
||||
const int64_t dst_offset = block_idx * cache_block_size * cache_stride +
|
||||
block_offset * head_dim + head_dim_idx;
|
||||
for (int i = 0; i < VEC_SIZE; i++) {
|
||||
kv_cache[dst_offset + i] =
|
||||
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(k_val_ptr[i], scale);
|
||||
}
|
||||
if (threadIdx.x == 0) {
|
||||
const int64_t dst_scale_idx =
|
||||
block_idx * cache_block_size * cache_stride +
|
||||
cache_block_size * head_dim +
|
||||
(block_offset * head_dim + head_dim_idx) * 4 / quant_block_size;
|
||||
reinterpret_cast<float*>(kv_cache)[dst_scale_idx / 4] = scale;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
// KV_T is the data type of key and value tensors.
|
||||
@@ -438,7 +609,7 @@ void reshape_and_cache(
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
|
||||
CALL_RESHAPE_AND_CACHE)
|
||||
CALL_RESHAPE_AND_CACHE);
|
||||
}
|
||||
|
||||
// KV_T is the data type of key and value tensors.
|
||||
@@ -509,6 +680,18 @@ void reshape_and_cache_flash(
|
||||
kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \
|
||||
reinterpret_cast<const float*>(scale.data_ptr()));
|
||||
|
||||
// KV_T is the data type of key and value tensors.
|
||||
// CACHE_T is the stored data type of kv-cache.
|
||||
#define CALL_CONCAT_AND_CACHE_DS_MLA(KV_T, CACHE_T, KV_DTYPE) \
|
||||
vllm::concat_and_cache_ds_mla_kernel<KV_T, CACHE_T, KV_DTYPE> \
|
||||
<<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<KV_T*>(kv_c.data_ptr()), \
|
||||
reinterpret_cast<KV_T*>(k_pe.data_ptr()), \
|
||||
reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
|
||||
slot_mapping.data_ptr<int64_t>(), block_stride, entry_stride, \
|
||||
kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \
|
||||
reinterpret_cast<const float*>(scale.data_ptr()));
|
||||
|
||||
void concat_and_cache_mla(
|
||||
torch::Tensor& kv_c, // [num_tokens, kv_lora_rank]
|
||||
torch::Tensor& k_pe, // [num_tokens, pe_dim]
|
||||
@@ -531,20 +714,44 @@ void concat_and_cache_mla(
|
||||
int pe_dim = k_pe.size(1);
|
||||
int block_size = kv_cache.size(1);
|
||||
|
||||
TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
|
||||
if (kv_cache_dtype == "fp8_ds_mla") {
|
||||
TORCH_CHECK(kv_lora_rank == 512, "kv_lora_rank must be 512 for fp8_ds_mla");
|
||||
TORCH_CHECK(pe_dim == 64, "pe_dim must be 64 for fp8_ds_mla");
|
||||
TORCH_CHECK(kv_cache.size(2) == 656 / kv_cache.itemsize(),
|
||||
"kv_cache.size(2) must be 656 bytes for fp8_ds_mla");
|
||||
TORCH_CHECK(kv_c.itemsize() == 2,
|
||||
"kv_c.itemsize() must be 2 for fp8_ds_mla");
|
||||
TORCH_CHECK(k_pe.itemsize() == 2,
|
||||
"k_pe.itemsize() must be 2 for fp8_ds_mla");
|
||||
} else {
|
||||
TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
|
||||
}
|
||||
|
||||
int kv_c_stride = kv_c.stride(0);
|
||||
int k_pe_stride = k_pe.stride(0);
|
||||
int block_stride = kv_cache.stride(0);
|
||||
int entry_stride = kv_cache.stride(1);
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(kv_lora_rank, 512));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_c));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
|
||||
CALL_CONCAT_AND_CACHE_MLA);
|
||||
if (kv_cache_dtype == "fp8_ds_mla") {
|
||||
dim3 grid(num_tokens);
|
||||
// For the NoPE part, each tile of 128 elements is handled by 4 warps
|
||||
// (128 threads). There are 4 total tiles, so 16 warps (512 threads).
|
||||
// The first thread of the first warp in each tile writes the scale
|
||||
// value for the tile. The RoPE part (last 64 elements) is handled
|
||||
// by another 2 warps (64 threads).
|
||||
// So in total, we use 18 warps (576 threads) per block.
|
||||
dim3 block(576);
|
||||
DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
|
||||
CALL_CONCAT_AND_CACHE_DS_MLA);
|
||||
} else {
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(kv_lora_rank, 512));
|
||||
DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
|
||||
CALL_CONCAT_AND_CACHE_MLA);
|
||||
}
|
||||
}
|
||||
|
||||
namespace vllm {
|
||||
@@ -922,3 +1129,42 @@ void cp_gather_cache(
|
||||
TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits);
|
||||
}
|
||||
}
|
||||
|
||||
// Macro to dispatch the kernel based on the data type.
|
||||
#define CALL_INDEXER_K_QUANT_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \
|
||||
vllm::indexer_k_quant_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE> \
|
||||
<<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<KV_T*>(k.data_ptr()), \
|
||||
reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
|
||||
slot_mapping.data_ptr<int64_t>(), head_dim, quant_block_size, \
|
||||
cache_block_size, cache_stride, use_ue8m0);
|
||||
|
||||
void indexer_k_quant_and_cache(
|
||||
torch::Tensor& k, // [num_tokens, head_dim]
|
||||
torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride]
|
||||
torch::Tensor& slot_mapping, // [num_tokens]
|
||||
int64_t quant_block_size, // quantization block size
|
||||
const std::string& scale_fmt) {
|
||||
int num_tokens = k.size(0);
|
||||
int head_dim = k.size(1);
|
||||
int cache_block_size = kv_cache.size(1);
|
||||
int cache_stride = kv_cache.size(2);
|
||||
bool use_ue8m0 = scale_fmt == "ue8m0";
|
||||
|
||||
TORCH_CHECK(k.device() == kv_cache.device(),
|
||||
"k and kv_cache must be on the same device");
|
||||
TORCH_CHECK(k.device() == slot_mapping.device(),
|
||||
"k and slot_mapping must be on the same device");
|
||||
TORCH_CHECK(head_dim % quant_block_size == 0,
|
||||
"head_dim must be divisible by quant_block_size");
|
||||
|
||||
constexpr int vec_size = 4;
|
||||
dim3 grid(num_tokens, (head_dim + quant_block_size * vec_size - 1) /
|
||||
(quant_block_size * vec_size));
|
||||
dim3 block(32, vec_size);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(k));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
DISPATCH_BY_KV_CACHE_DTYPE(k.dtype(), "fp8_e4m3",
|
||||
CALL_INDEXER_K_QUANT_AND_CACHE);
|
||||
}
|
||||
Reference in New Issue
Block a user