Support tensor parallel (#2)
This commit is contained in:
@@ -11,7 +11,6 @@ class CacheEngine:
|
||||
def __init__(
|
||||
self,
|
||||
worker_id: int,
|
||||
gpu_id: int,
|
||||
num_layers: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
@@ -25,7 +24,6 @@ class CacheEngine:
|
||||
f'head_size ({head_size}) must be a multiple of 16.')
|
||||
|
||||
self.worker_id = worker_id
|
||||
self.gpu_id = gpu_id
|
||||
self.num_layers = num_layers
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
@@ -39,8 +37,8 @@ class CacheEngine:
|
||||
self.cpu_cache = self.allocate_cpu_cache()
|
||||
|
||||
# Initialize the stream for caching operations.
|
||||
self.cache_stream = torch.cuda.Stream(device=gpu_id)
|
||||
assert self.cache_stream != torch.cuda.current_stream(device=gpu_id)
|
||||
self.cache_stream = torch.cuda.Stream()
|
||||
assert self.cache_stream != torch.cuda.current_stream()
|
||||
# Initialize the events for stream synchronization.
|
||||
self.events = [torch.cuda.Event() for _ in range(num_layers)]
|
||||
|
||||
@@ -69,12 +67,12 @@ class CacheEngine:
|
||||
key_blocks = torch.empty(
|
||||
size=(self.num_gpu_blocks, *key_block_shape),
|
||||
dtype=self.dtype,
|
||||
device=self.gpu_id,
|
||||
device="cuda",
|
||||
)
|
||||
value_blocks = torch.empty(
|
||||
size=(self.num_gpu_blocks, *value_block_shape),
|
||||
dtype=self.dtype,
|
||||
device=self.gpu_id,
|
||||
device="cuda",
|
||||
)
|
||||
gpu_cache.append((key_blocks, value_blocks))
|
||||
return gpu_cache
|
||||
|
||||
Reference in New Issue
Block a user