Add 320 dimension size support to MLA (#36161)

Signed-off-by: Julien Denize <julien.denize@mistral.ai>
This commit is contained in:
Julien Denize
2026-03-11 18:21:22 +01:00
committed by GitHub
parent 5efa206a8c
commit a5d06dc557
3 changed files with 25 additions and 9 deletions

View File

@@ -919,8 +919,8 @@ __global__ void gather_and_maybe_dequant_cache(
// SCALAR_T is the data type of the destination tensor.
// CACHE_T is the stored data type of kv-cache.
// KV_DTYPE is the real data type of kv-cache.
#define CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE) \
vllm::gather_and_maybe_dequant_cache<SCALAR_T, CACHE_T, KV_DTYPE, 576, \
#define CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE, ENTRY_SZ) \
vllm::gather_and_maybe_dequant_cache<SCALAR_T, CACHE_T, KV_DTYPE, ENTRY_SZ, \
thread_block_size> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<CACHE_T*>(src_cache.data_ptr()), \
@@ -931,6 +931,12 @@ __global__ void gather_and_maybe_dequant_cache(
dst_entry_stride, reinterpret_cast<const float*>(scale.data_ptr()), \
seq_starts_ptr);
#define CALL_GATHER_CACHE_576(SCALAR_T, CACHE_T, KV_DTYPE) \
CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE, 576)
#define CALL_GATHER_CACHE_320(SCALAR_T, CACHE_T, KV_DTYPE) \
CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE, 320)
// Gather sequences from the cache into the destination tensor.
// - cu_seq_lens contains the cumulative sequence lengths for each batch
// - block_table contains the cache block indices for each sequence
@@ -960,9 +966,10 @@ void gather_and_maybe_dequant_cache(
TORCH_CHECK(seq_starts.value().dtype() == torch::kInt32,
"seq_starts must be int32");
}
TORCH_CHECK(head_dim == 576,
"gather_and_maybe_dequant_cache only support the head_dim to 576 "
"for better performance")
TORCH_CHECK(
head_dim == 320 || head_dim == 576,
"gather_and_maybe_dequant_cache only support the head_dim to 320 or 576 "
"for better performance")
TORCH_CHECK(src_cache.device() == dst.device(),
"src_cache and dst must be on the same device");
@@ -987,7 +994,13 @@ void gather_and_maybe_dequant_cache(
const int32_t* seq_starts_ptr =
seq_starts.has_value() ? seq_starts.value().data_ptr<int32_t>() : nullptr;
DISPATCH_BY_KV_CACHE_DTYPE(dst.dtype(), kv_cache_dtype, CALL_GATHER_CACHE);
if (head_dim == 576) {
DISPATCH_BY_KV_CACHE_DTYPE(dst.dtype(), kv_cache_dtype,
CALL_GATHER_CACHE_576);
} else {
DISPATCH_BY_KV_CACHE_DTYPE(dst.dtype(), kv_cache_dtype,
CALL_GATHER_CACHE_320);
}
}
namespace vllm {