diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu
index d2418a7f8..4b07f9b53 100644
--- a/csrc/cache_kernels.cu
+++ b/csrc/cache_kernels.cu
@@ -919,8 +919,8 @@ __global__ void gather_and_maybe_dequant_cache(
 // SCALAR_T is the data type of the destination tensor.
 // CACHE_T is the stored data type of kv-cache.
 // KV_DTYPE is the real data type of kv-cache.
-#define CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE)                         \
-  vllm::gather_and_maybe_dequant_cache<SCALAR_T, CACHE_T, KV_DTYPE, 576>      \
+#define CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE, HEAD_DIM)              \
+  vllm::gather_and_maybe_dequant_cache<SCALAR_T, CACHE_T, KV_DTYPE, HEAD_DIM> \
       <<<grid, block, 0, stream>>>(                                           \
           reinterpret_cast<CACHE_T*>(src_cache.data_ptr()),                   \
@@ -931,6 +931,12 @@ __global__ void gather_and_maybe_dequant_cache(
           dst_entry_stride, reinterpret_cast<const float*>(scale.data_ptr()), \
           seq_starts_ptr);
+#define CALL_GATHER_CACHE_576(SCALAR_T, CACHE_T, KV_DTYPE) \
+  CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE, 576)
+
+#define CALL_GATHER_CACHE_320(SCALAR_T, CACHE_T, KV_DTYPE) \
+  CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE, 320)
+
 // Gather sequences from the cache into the destination tensor.
 // - cu_seq_lens contains the cumulative sequence lengths for each batch
 // - block_table contains the cache block indices for each sequence
@@ -960,9 +966,10 @@ void gather_and_maybe_dequant_cache(
     TORCH_CHECK(seq_starts.value().dtype() == torch::kInt32,
                 "seq_starts must be int32");
   }
-  TORCH_CHECK(head_dim == 576,
-              "gather_and_maybe_dequant_cache only support the head_dim to 576 "
-              "for better performance")
+  TORCH_CHECK(
+      head_dim == 320 || head_dim == 576,
+      "gather_and_maybe_dequant_cache only support the head_dim to 320 or 576 "
+      "for better performance")

   TORCH_CHECK(src_cache.device() == dst.device(),
               "src_cache and dst must be on the same device");
@@ -987,7 +994,13 @@ void gather_and_maybe_dequant_cache(
   const int32_t* seq_starts_ptr = seq_starts.has_value() ?
       seq_starts.value().data_ptr<int32_t>() : nullptr;

-  DISPATCH_BY_KV_CACHE_DTYPE(dst.dtype(), kv_cache_dtype, CALL_GATHER_CACHE);
+  if (head_dim == 576) {
+    DISPATCH_BY_KV_CACHE_DTYPE(dst.dtype(), kv_cache_dtype,
+                               CALL_GATHER_CACHE_576);
+  } else {
+    DISPATCH_BY_KV_CACHE_DTYPE(dst.dtype(), kv_cache_dtype,
+                               CALL_GATHER_CACHE_320);
+  }
 }

 namespace vllm {
diff --git a/tests/kernels/attention/test_cache.py b/tests/kernels/attention/test_cache.py
index 4ff1e590a..7c60a8a14 100644
--- a/tests/kernels/attention/test_cache.py
+++ b/tests/kernels/attention/test_cache.py
@@ -23,7 +23,7 @@ CACHE_LAYOUTS = ["NHD", "HND"]
 KV_SCALE_TYPES = ["tensor", "attn_head"]

 # Parameters for MLA tests.
-KV_LORA_RANKS = [512]
+KV_LORA_RANKS = [256, 512]
 QK_ROPE_HEAD_DIMS = [64]
 NUM_TOKENS_MLA = [42]
 BLOCK_SIZES_MLA = [16]
@@ -627,6 +627,8 @@ def test_concat_and_cache_ds_mla(
         pytest.skip("concat_and_cache_mla doesn't support fp8_ds_mla on ROCm")
     if dtype.itemsize != 2:
         pytest.skip("ds_mla only supports 16-bit input")
+    if kv_lora_rank != 512:
+        pytest.skip("fp8_ds_mla requires kv_lora_rank == 512")
     kv_cache_dtype = "fp8_ds_mla"
     set_random_seed(seed)
     torch.set_default_device(device)
@@ -663,7 +665,8 @@ def test_concat_and_cache_ds_mla(
         ref_cache_32bit = ref_cache_slice.view(torch.float32)

         kv_c_data = kv_c[i]
-        for tile_idx in range(4):
+        num_tiles = kv_lora_rank // 128
+        for tile_idx in range(num_tiles):
             tile_start = tile_idx * 128
             tile_end = (tile_idx + 1) * 128
             tile_data[:] = kv_c_data[tile_start:tile_end]
diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py
index b1dc1a860..36ccc649f 100644
--- a/vllm/model_executor/layers/attention/mla_attention.py
+++ b/vllm/model_executor/layers/attention/mla_attention.py
@@ -1148,7 +1148,7 @@ class MLACommonBackend(AttentionBackend):

     @classmethod
     def get_supported_head_sizes(cls) -> list[int]:
-        return [576]
+        return [320, 576]

     @classmethod
     def is_mla(cls) -> bool: