Support FP8-E5M2 KV Cache (#2279)
Co-authored-by: zhaoyang <zhao.yang16@zte.com.cn>
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
@@ -17,6 +17,7 @@ class EngineArgs:
     download_dir: Optional[str] = None
     load_format: str = 'auto'
     dtype: str = 'auto'
+    kv_cache_dtype: str = 'auto'
     seed: int = 0
     max_model_len: Optional[int] = None
     worker_use_ray: bool = False
@@ -122,6 +123,14 @@ class EngineArgs:
             'The "auto" option will use FP16 precision '
             'for FP32 and FP16 models, and BF16 precision '
             'for BF16 models.')
+        parser.add_argument(
+            '--kv-cache-dtype',
+            type=str,
+            choices=['auto', 'fp8_e5m2'],
+            default='auto',
+            help='Data type for kv cache storage. If "auto", will use model '
+            'data type. Note FP8 is not supported when cuda version is '
+            'lower than 11.8.')
         parser.add_argument('--max-model-len',
                             type=int,
                             default=None,
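For context, here is a minimal usage sketch of the new option from the Python side. It assumes the vllm.LLM entrypoint forwards extra keyword arguments such as kv_cache_dtype through to EngineArgs, mirroring the --kv-cache-dtype flag added above; the model name is only a placeholder.

    # Hedged sketch: enable the FP8-E5M2 KV cache from the Python entrypoint.
    # Assumes LLM(**kwargs) feeds kv_cache_dtype through to EngineArgs.
    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m", kv_cache_dtype="fp8_e5m2")
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(max_tokens=16))
    print(outputs[0].outputs[0].text)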
@@ -269,7 +278,7 @@ class EngineArgs:
                                    self.max_context_len_to_capture)
         cache_config = CacheConfig(self.block_size,
                                    self.gpu_memory_utilization,
-                                   self.swap_space,
+                                   self.swap_space, self.kv_cache_dtype,
                                    model_config.get_sliding_window())
         parallel_config = ParallelConfig(self.pipeline_parallel_size,
                                          self.tensor_parallel_size,
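The CacheConfig call above now takes the KV cache dtype between swap_space and the sliding window. A rough sketch of the constructor this call site implies, including a simple validation step; everything except the cache_dtype argument and the 'auto'/'fp8_e5m2' choices is an assumption inferred from the surrounding diff, not the actual implementation.

    from typing import Optional

    class CacheConfig:
        # Sketch inferred from the call sites in this diff; not vLLM's real code.
        def __init__(self,
                     block_size: int,
                     gpu_memory_utilization: float,
                     swap_space: int,
                     cache_dtype: str,
                     sliding_window: Optional[int] = None) -> None:
            if cache_dtype not in ("auto", "fp8_e5m2"):
                raise ValueError(f"Unsupported kv cache dtype: {cache_dtype}")
            self.block_size = block_size
            self.gpu_memory_utilization = gpu_memory_utilization
            self.swap_space_bytes = swap_space * (1 << 30)  # GiB -> bytes
            self.cache_dtype = cache_dtype
            self.sliding_window = sliding_window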
@@ -85,6 +85,7 @@ class LLMEngine:
             f"disable_custom_all_reduce={parallel_config.disable_custom_all_reduce}, "
             f"quantization={model_config.quantization}, "
             f"enforce_eager={model_config.enforce_eager}, "
+            f"kv_cache_dtype={cache_config.cache_dtype}, "
             f"seed={model_config.seed})")
         # TODO(woosuk): Print more configs in debug mode.
@@ -144,6 +145,7 @@ class LLMEngine:
             rank=0,
             distributed_init_method=distributed_init_method,
             lora_config=self.lora_config,
+            kv_cache_dtype=self.cache_config.cache_dtype,
             is_driver_worker=True,
         )
         self._run_workers("init_model")
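Workers now receive the KV cache dtype, directly as kv_cache_dtype for the driver worker here and via cache_config on the Ray path below. One plausible way a worker could resolve the string into a storage dtype is sketched next; mapping fp8_e5m2 onto uint8 storage and the CUDA 11.8 check echo the help text above, but both are assumptions about the implementation rather than part of this diff.

    import torch

    def resolve_kv_cache_dtype(kv_cache_dtype: str,
                               model_dtype: torch.dtype) -> torch.dtype:
        # "auto" keeps the model dtype; fp8_e5m2 values are assumed to be packed
        # into uint8 storage and converted inside the attention kernels.
        if kv_cache_dtype == "auto":
            return model_dtype
        if kv_cache_dtype == "fp8_e5m2":
            cuda_ver = torch.version.cuda  # e.g. "11.8"; None on CPU-only builds
            if cuda_ver is None or tuple(map(int, cuda_ver.split(".")[:2])) < (11, 8):
                raise ValueError("FP8 KV cache requires CUDA 11.8 or newer.")
            return torch.uint8
        raise ValueError(f"Unknown kv cache dtype: {kv_cache_dtype}")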
@@ -234,6 +236,7 @@ class LLMEngine:
         model_config = copy.deepcopy(self.model_config)
         parallel_config = copy.deepcopy(self.parallel_config)
         scheduler_config = copy.deepcopy(self.scheduler_config)
+        cache_config = copy.deepcopy(self.cache_config)

         for rank, (worker, (node_id,
                             _)) in enumerate(zip(self.workers,
@@ -249,6 +252,7 @@ class LLMEngine:
                 rank,
                 distributed_init_method,
                 lora_config=self.lora_config,
+                cache_config=cache_config,
             ))

         driver_rank = 0
@@ -261,6 +265,7 @@ class LLMEngine:
             driver_rank,
             distributed_init_method,
             lora_config=self.lora_config,
+            cache_config=cache_config,
             is_driver_worker=True,
         )
@@ -306,6 +311,7 @@ class LLMEngine:
             block_size=self.cache_config.block_size,
             gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
             cpu_swap_space=self.cache_config.swap_space_bytes,
+            cache_dtype=self.cache_config.cache_dtype,
         )

         # Since we use a shared centralized controller, we take the minimum
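The profiling call above needs cache_dtype because element size determines how many KV cache blocks fit in the GPU memory budget: fp8_e5m2 stores one byte per element versus two for fp16/bf16, so roughly twice as many blocks fit. The arithmetic below is only illustrative; the helper name and the example shapes are assumptions, not vLLM APIs.

    # Illustrative only: per-block KV cache size for different cache dtypes.
    DTYPE_BYTES = {"fp16": 2, "bf16": 2, "fp8_e5m2": 1}

    def cache_block_bytes(block_size: int, num_layers: int, num_kv_heads: int,
                          head_size: int, cache_dtype: str) -> int:
        # One key block plus one value block for every layer.
        per_token = 2 * num_layers * num_kv_heads * head_size
        return per_token * block_size * DTYPE_BYTES[cache_dtype]

    fp16_block = cache_block_bytes(16, 32, 32, 128, "fp16")
    fp8_block = cache_block_bytes(16, 32, 32, 128, "fp8_e5m2")
    print(fp16_block // fp8_block)  # 2: double the blocks in the same budget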