diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index f73431fd7..b184f6574 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -322,7 +322,7 @@ class TpKVTopology: # Figure out whether the first dimension of the cache is K/V # or num_blocks. This is used to register the memory regions correctly. kv_cache_shape = self.attn_backend.get_kv_cache_shape( - num_blocks=1, block_size=16, num_kv_heads=4, head_size=1 + num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 ) # Non-MLA backends caches have 5 dims [2, num_blocks, H,N,D], # we just mock num_blocks to 1 for the dimension check below.