[v1] Refactor KVCacheConfig (#14079)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Author: Chen Zhang
Date: 2025-03-21 19:56:27 +08:00
Committed-by: GitHub
Parent: 61e8c18350
Commit: 93a00d7dde
10 changed files with 318 additions and 110 deletions


@@ -62,14 +62,11 @@ class Executor(ExecutorBase):
                             args=(kv_cache_configs, ))
         self.collective_rpc("compile_or_warm_up_model")
 
-    def determine_available_memory(self) -> int:  # in bytes
+    def determine_available_memory(self) -> list[int]:  # in bytes
         output = self.collective_rpc("determine_available_memory")
-        # Since we use a shared centralized controller, we take the minimum
-        # memory size across all workers to make sure all the memory
-        # operators can be applied to all workers.
-        return min(output)
+        return output
 
-    def get_kv_cache_specs(self) -> list[KVCacheSpec]:
+    def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]:
         output = self.collective_rpc("get_kv_cache_spec")
         return output
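
With this hunk, the executor reports one available-memory figure per worker instead of pre-reducing them, so the minimum is taken wherever the KV cache configuration is built. A minimal sketch of that caller-side reduction, assuming only the new list[int] contract; reduce_available_memory is a hypothetical helper for illustration, not code from this commit:

def reduce_available_memory(per_worker_memory: list[int]) -> int:
    # Hypothetical helper mirroring the min() removed above: with a shared
    # centralized controller, the chosen memory budget must fit on the most
    # constrained worker, so take the minimum across all workers.
    return min(per_worker_memory)

# Example usage (illustrative):
#   budget = reduce_available_memory(executor.determine_available_memory())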
@@ -95,7 +92,7 @@ class UniProcExecutor(UniProcExecutorV0, Executor):
 class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor):
 
-    def determine_available_memory(self) -> int:  # in bytes
+    def determine_available_memory(self) -> list[int]:  # in bytes
         # same as determine_num_available_blocks in v0,
         # we need to get the min across all ranks.
         memory = super().determine_available_memory()
@@ -103,4 +100,4 @@ class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor):
         cpu_group = get_world_group().cpu_group
         memory_tensor = torch.tensor([memory], device="cpu", dtype=torch.int64)
         dist.all_reduce(memory_tensor, group=cpu_group, op=dist.ReduceOp.MIN)
-        return memory_tensor.item()
+        return [memory_tensor.item()]
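
The external-launcher path keeps its cross-rank reduction because, as the comment in the hunk above notes, it still needs the min across all ranks. A standalone sketch of the same all-reduce-MIN pattern, assuming a CPU-capable process group (e.g. "gloo") has already been initialized via torchrun or similar; min_across_ranks is an illustrative name, not part of this commit:

import torch
import torch.distributed as dist

def min_across_ranks(local_value: int) -> int:
    # Every rank contributes its local measurement; after the all-reduce
    # with ReduceOp.MIN, all ranks hold the global minimum.
    tensor = torch.tensor([local_value], device="cpu", dtype=torch.int64)
    dist.all_reduce(tensor, op=dist.ReduceOp.MIN)
    return int(tensor.item())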