Add 320 dimension size support to MLA (#36161)
Signed-off-by: Julien Denize <julien.denize@mistral.ai>
This commit is contained in:
@@ -919,8 +919,8 @@ __global__ void gather_and_maybe_dequant_cache(
|
|||||||
// SCALAR_T is the data type of the destination tensor.
|
// SCALAR_T is the data type of the destination tensor.
|
||||||
// CACHE_T is the stored data type of kv-cache.
|
// CACHE_T is the stored data type of kv-cache.
|
||||||
// KV_DTYPE is the real data type of kv-cache.
|
// KV_DTYPE is the real data type of kv-cache.
|
||||||
#define CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE) \
|
#define CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE, ENTRY_SZ) \
|
||||||
vllm::gather_and_maybe_dequant_cache<SCALAR_T, CACHE_T, KV_DTYPE, 576, \
|
vllm::gather_and_maybe_dequant_cache<SCALAR_T, CACHE_T, KV_DTYPE, ENTRY_SZ, \
|
||||||
thread_block_size> \
|
thread_block_size> \
|
||||||
<<<grid, block, 0, stream>>>( \
|
<<<grid, block, 0, stream>>>( \
|
||||||
reinterpret_cast<CACHE_T*>(src_cache.data_ptr()), \
|
reinterpret_cast<CACHE_T*>(src_cache.data_ptr()), \
|
||||||
@@ -931,6 +931,12 @@ __global__ void gather_and_maybe_dequant_cache(
|
|||||||
dst_entry_stride, reinterpret_cast<const float*>(scale.data_ptr()), \
|
dst_entry_stride, reinterpret_cast<const float*>(scale.data_ptr()), \
|
||||||
seq_starts_ptr);
|
seq_starts_ptr);
|
||||||
|
|
||||||
|
#define CALL_GATHER_CACHE_576(SCALAR_T, CACHE_T, KV_DTYPE) \
|
||||||
|
CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE, 576)
|
||||||
|
|
||||||
|
#define CALL_GATHER_CACHE_320(SCALAR_T, CACHE_T, KV_DTYPE) \
|
||||||
|
CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE, 320)
|
||||||
|
|
||||||
// Gather sequences from the cache into the destination tensor.
|
// Gather sequences from the cache into the destination tensor.
|
||||||
// - cu_seq_lens contains the cumulative sequence lengths for each batch
|
// - cu_seq_lens contains the cumulative sequence lengths for each batch
|
||||||
// - block_table contains the cache block indices for each sequence
|
// - block_table contains the cache block indices for each sequence
|
||||||
@@ -960,9 +966,10 @@ void gather_and_maybe_dequant_cache(
|
|||||||
TORCH_CHECK(seq_starts.value().dtype() == torch::kInt32,
|
TORCH_CHECK(seq_starts.value().dtype() == torch::kInt32,
|
||||||
"seq_starts must be int32");
|
"seq_starts must be int32");
|
||||||
}
|
}
|
||||||
TORCH_CHECK(head_dim == 576,
|
TORCH_CHECK(
|
||||||
"gather_and_maybe_dequant_cache only support the head_dim to 576 "
|
head_dim == 320 || head_dim == 576,
|
||||||
"for better performance")
|
"gather_and_maybe_dequant_cache only support the head_dim to 320 or 576 "
|
||||||
|
"for better performance")
|
||||||
|
|
||||||
TORCH_CHECK(src_cache.device() == dst.device(),
|
TORCH_CHECK(src_cache.device() == dst.device(),
|
||||||
"src_cache and dst must be on the same device");
|
"src_cache and dst must be on the same device");
|
||||||
@@ -987,7 +994,13 @@ void gather_and_maybe_dequant_cache(
|
|||||||
const int32_t* seq_starts_ptr =
|
const int32_t* seq_starts_ptr =
|
||||||
seq_starts.has_value() ? seq_starts.value().data_ptr<int32_t>() : nullptr;
|
seq_starts.has_value() ? seq_starts.value().data_ptr<int32_t>() : nullptr;
|
||||||
|
|
||||||
DISPATCH_BY_KV_CACHE_DTYPE(dst.dtype(), kv_cache_dtype, CALL_GATHER_CACHE);
|
if (head_dim == 576) {
|
||||||
|
DISPATCH_BY_KV_CACHE_DTYPE(dst.dtype(), kv_cache_dtype,
|
||||||
|
CALL_GATHER_CACHE_576);
|
||||||
|
} else {
|
||||||
|
DISPATCH_BY_KV_CACHE_DTYPE(dst.dtype(), kv_cache_dtype,
|
||||||
|
CALL_GATHER_CACHE_320);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace vllm {
|
namespace vllm {
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ CACHE_LAYOUTS = ["NHD", "HND"]
|
|||||||
KV_SCALE_TYPES = ["tensor", "attn_head"]
|
KV_SCALE_TYPES = ["tensor", "attn_head"]
|
||||||
|
|
||||||
# Parameters for MLA tests.
|
# Parameters for MLA tests.
|
||||||
KV_LORA_RANKS = [512]
|
KV_LORA_RANKS = [256, 512]
|
||||||
QK_ROPE_HEAD_DIMS = [64]
|
QK_ROPE_HEAD_DIMS = [64]
|
||||||
NUM_TOKENS_MLA = [42]
|
NUM_TOKENS_MLA = [42]
|
||||||
BLOCK_SIZES_MLA = [16]
|
BLOCK_SIZES_MLA = [16]
|
||||||
@@ -627,6 +627,8 @@ def test_concat_and_cache_ds_mla(
|
|||||||
pytest.skip("concat_and_cache_mla doesn't support fp8_ds_mla on ROCm")
|
pytest.skip("concat_and_cache_mla doesn't support fp8_ds_mla on ROCm")
|
||||||
if dtype.itemsize != 2:
|
if dtype.itemsize != 2:
|
||||||
pytest.skip("ds_mla only supports 16-bit input")
|
pytest.skip("ds_mla only supports 16-bit input")
|
||||||
|
if kv_lora_rank != 512:
|
||||||
|
pytest.skip("fp8_ds_mla requires kv_lora_rank == 512")
|
||||||
kv_cache_dtype = "fp8_ds_mla"
|
kv_cache_dtype = "fp8_ds_mla"
|
||||||
set_random_seed(seed)
|
set_random_seed(seed)
|
||||||
torch.set_default_device(device)
|
torch.set_default_device(device)
|
||||||
@@ -663,7 +665,8 @@ def test_concat_and_cache_ds_mla(
|
|||||||
ref_cache_32bit = ref_cache_slice.view(torch.float32)
|
ref_cache_32bit = ref_cache_slice.view(torch.float32)
|
||||||
|
|
||||||
kv_c_data = kv_c[i]
|
kv_c_data = kv_c[i]
|
||||||
for tile_idx in range(4):
|
num_tiles = kv_lora_rank // 128
|
||||||
|
for tile_idx in range(num_tiles):
|
||||||
tile_start = tile_idx * 128
|
tile_start = tile_idx * 128
|
||||||
tile_end = (tile_idx + 1) * 128
|
tile_end = (tile_idx + 1) * 128
|
||||||
tile_data[:] = kv_c_data[tile_start:tile_end]
|
tile_data[:] = kv_c_data[tile_start:tile_end]
|
||||||
|
|||||||
@@ -1148,7 +1148,7 @@ class MLACommonBackend(AttentionBackend):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_supported_head_sizes(cls) -> list[int]:
|
def get_supported_head_sizes(cls) -> list[int]:
|
||||||
return [576]
|
return [320, 576]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_mla(cls) -> bool:
|
def is_mla(cls) -> bool:
|
||||||
|
|||||||
Reference in New Issue
Block a user