Allocate more shared memory to attention kernel (#1154)
vllm/utils.py
@@ -1,10 +1,12 @@
 import enum
-from platform import uname
 import uuid
+from platform import uname
 
 import psutil
 import torch
 
+from vllm import cuda_utils
+
 
 class Device(enum.Enum):
     GPU = enum.auto()
@@ -25,6 +27,15 @@ class Counter:
         self.counter = 0
 
 
+def get_max_shared_memory_bytes(gpu: int = 0) -> int:
+    """Returns the maximum shared memory per thread block in bytes."""
+    # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
+    cudaDevAttrMaxSharedMemoryPerBlockOptin = 97  # pylint: disable=invalid-name
+    max_shared_mem = cuda_utils.get_device_attribute(
+        cudaDevAttrMaxSharedMemoryPerBlockOptin, gpu)
+    return int(max_shared_mem)
+
+
 def get_gpu_memory(gpu: int = 0) -> int:
     """Returns the total memory of the GPU in bytes."""
     return torch.cuda.get_device_properties(gpu).total_memory
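For readers without the compiled cuda_utils extension at hand: the helper above is essentially a thin wrapper around the CUDA runtime call cudaDeviceGetAttribute. Below is a rough, hypothetical ctypes sketch of the same query; the "libcudart.so" name and the hand-rolled get_device_attribute wrapper are assumptions for illustration, not part of this commit.

import ctypes

# Sketch of what cuda_utils.get_device_attribute does, via the CUDA runtime.
# Assumes libcudart.so is on the loader path; adjust the name if needed.
_cudart = ctypes.CDLL("libcudart.so")


def get_device_attribute(attr: int, device: int = 0) -> int:
    value = ctypes.c_int()
    # int cudaDeviceGetAttribute(int* value, cudaDeviceAttr attr, int device)
    err = _cudart.cudaDeviceGetAttribute(ctypes.byref(value), attr, device)
    if err != 0:
        raise RuntimeError(f"cudaDeviceGetAttribute failed with error {err}")
    return value.value


# 97 == cudaDevAttrMaxSharedMemoryPerBlockOptin, as in the diff above.
print(get_device_attribute(97, 0))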
vllm/worker/worker.py
@@ -13,7 +13,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
 from vllm.worker.cache_engine import CacheEngine
-from vllm.utils import get_gpu_memory
+from vllm.utils import get_gpu_memory, get_max_shared_memory_bytes
 
 
 class Worker:
@@ -136,6 +136,10 @@ class Worker:
     def init_cache_engine(self, cache_config: CacheConfig) -> None:
         self.cache_config = cache_config
        self.block_size = cache_config.block_size
+
+        _check_if_can_support_max_seq_len(self.scheduler_config.max_model_len,
+                                          self.block_size)
+
         self.cache_engine = CacheEngine(self.cache_config, self.model_config,
                                         self.parallel_config)
         self.cache_events = self.cache_engine.events
@@ -347,3 +351,23 @@ def _pad_to_alignment(x: List[int], multiple_of: int) -> List[int]:
 
 def _pad_to_max(x: List[int], max_len: int) -> List[int]:
     return x + [0] * (max_len - len(x))
+
+
+def _check_if_can_support_max_seq_len(max_seq_len: int,
+                                      block_size: int) -> None:
+    # Follows the logic in
+    # attention_kernels.cu::single_query_cached_kv_attention_launcher
+    max_shared_mem = get_max_shared_memory_bytes()
+    float32_bytes = torch.finfo(torch.float).bits // 8
+    padded_max_seq_len = (
+        (max_seq_len + block_size - 1) / block_size) * block_size
+    # padded_max_seq_len + extra buffer
+    required_shared_mem = (padded_max_seq_len + 512) * float32_bytes
+    if padded_max_seq_len * float32_bytes > max_shared_mem:
+        raise RuntimeError(
+            f"vLLM cannot currently support max_model_len={max_seq_len} "
+            f"with block_size={block_size} on GPU with compute "
+            f"capability {torch.cuda.get_device_capability()} "
+            f"(required shared memory {required_shared_mem} > "
+            f"available shared memory {max_shared_mem}). "
+            "This will be fixed in a future release.")
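The new _check_if_can_support_max_seq_len pads max_seq_len up to a multiple of block_size, charges one float32 (4 bytes) of shared memory per padded position, and raises if that footprint exceeds the opt-in limit reported by get_max_shared_memory_bytes; the error message additionally counts the 512-element buffer via required_shared_mem, although the comparison itself uses only the padded length. Below is a GPU-free sketch of the same arithmetic; the 96 KB and 48 KB budgets are assumed values for illustration, whereas real budgets come from the device query above.

def max_supported_model_len(max_shared_mem: int = 96 * 1024,
                            block_size: int = 16) -> int:
    """Largest max_model_len that passes the shared-memory check above."""
    float32_bytes = 4  # torch.finfo(torch.float).bits // 8
    # Largest padded length whose float32 footprint still fits, rounded down
    # to a multiple of block_size (the check rounds max_seq_len up).
    return (max_shared_mem // float32_bytes) // block_size * block_size


print(max_supported_model_len(96 * 1024))  # 24576
print(max_supported_model_len(48 * 1024))  # 12288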