[FP8][Kernel] Dynamic kv cache scaling factors computation (#11906)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com> Co-authored-by: Micah Williamson <micah.williamson@amd.com>
This commit is contained in:
committed by
GitHub
parent
6e650f56a1
commit
e97f802b2d
@@ -98,7 +98,6 @@ class EngineArgs:
|
||||
config_format: ConfigFormat = ConfigFormat.AUTO
|
||||
dtype: str = 'auto'
|
||||
kv_cache_dtype: str = 'auto'
|
||||
quantization_param_path: Optional[str] = None
|
||||
seed: int = 0
|
||||
max_model_len: Optional[int] = None
|
||||
worker_use_ray: bool = False
|
||||
@@ -199,6 +198,8 @@ class EngineArgs:
|
||||
generation_config: Optional[str] = None
|
||||
enable_sleep_mode: bool = False
|
||||
|
||||
calculate_kv_scales: Optional[bool] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if not self.tokenizer:
|
||||
self.tokenizer = self.model
|
||||
@@ -350,17 +351,6 @@ class EngineArgs:
|
||||
help='Data type for kv cache storage. If "auto", will use model '
|
||||
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
|
||||
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
|
||||
parser.add_argument(
|
||||
'--quantization-param-path',
|
||||
type=nullable_str,
|
||||
default=None,
|
||||
help='Path to the JSON file containing the KV cache '
|
||||
'scaling factors. This should generally be supplied, when '
|
||||
'KV cache dtype is FP8. Otherwise, KV cache scaling factors '
|
||||
'default to 1.0, which may cause accuracy issues. '
|
||||
'FP8_E5M2 (without scaling) is only supported on cuda version '
|
||||
'greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
|
||||
'supported for common inference criteria.')
|
||||
parser.add_argument('--max-model-len',
|
||||
type=int,
|
||||
default=EngineArgs.max_model_len,
|
||||
@@ -962,6 +952,15 @@ class EngineArgs:
|
||||
help="Enable sleep mode for the engine. "
|
||||
"(only cuda platform is supported)")
|
||||
|
||||
parser.add_argument(
|
||||
'--calculate-kv-scales',
|
||||
action='store_true',
|
||||
help='This enables dynamic calculation of '
|
||||
'k_scale and v_scale when kv-cache-dtype is fp8. '
|
||||
'If calculate-kv-scales is false, the scales will '
|
||||
'be loaded from the model checkpoint if available. '
|
||||
'Otherwise, the scales will default to 1.0.')
|
||||
|
||||
return parser
|
||||
|
||||
@classmethod
|
||||
@@ -991,7 +990,6 @@ class EngineArgs:
|
||||
tokenizer_revision=self.tokenizer_revision,
|
||||
max_model_len=self.max_model_len,
|
||||
quantization=self.quantization,
|
||||
quantization_param_path=self.quantization_param_path,
|
||||
enforce_eager=self.enforce_eager,
|
||||
max_seq_len_to_capture=self.max_seq_len_to_capture,
|
||||
max_logprobs=self.max_logprobs,
|
||||
@@ -1068,6 +1066,7 @@ class EngineArgs:
|
||||
sliding_window=model_config.get_sliding_window(),
|
||||
enable_prefix_caching=self.enable_prefix_caching,
|
||||
cpu_offload_gb=self.cpu_offload_gb,
|
||||
calculate_kv_scales=self.calculate_kv_scales,
|
||||
)
|
||||
parallel_config = ParallelConfig(
|
||||
pipeline_parallel_size=self.pipeline_parallel_size,
|
||||
|
||||
Reference in New Issue
Block a user