[Misc] Load FP8 kv-cache scaling factors from checkpoints (#4893)
This is the second PR for #4532. It adds support for loading FP8 kv-cache scaling factors from an FP8 checkpoint (via a .kv_scale parameter on the attention layers).
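For orientation, here is a minimal, self-contained sketch of the intended flow; it is not vLLM's actual weight loader, and ToyAttention, load_kv_scale_from_checkpoint, and the "attn.kv_scale" key are hypothetical names. The idea: the checkpoint carries a scalar .kv_scale per attention layer, which overwrites the default of 1.0 and is then collapsed into a plain Python float, mirroring the Fp8KVCacheMethod.process_weights_after_loading logic added in the diff below.

import torch
from torch.nn import Parameter


class ToyAttention(torch.nn.Module):
    # Hypothetical stand-in for vllm.attention.layer.Attention.
    def __init__(self, kv_cache_dtype: str):
        super().__init__()
        self.kv_cache_dtype = kv_cache_dtype
        # Default scale; overwritten if the checkpoint provides one.
        self.kv_scale = Parameter(torch.tensor(1.0), requires_grad=False)


def load_kv_scale_from_checkpoint(layer: ToyAttention, state_dict: dict) -> None:
    # Assumption: a per-tensor FP8 checkpoint stores a scalar under "<layer>.kv_scale".
    if "attn.kv_scale" in state_dict:
        layer.kv_scale.data.copy_(state_dict["attn.kv_scale"])
    # Mirror of Fp8KVCacheMethod.process_weights_after_loading: collapse the
    # Parameter into a plain float (_kv_scale) when the kv-cache dtype is fp8.
    if layer.kv_cache_dtype != "auto":
        kv_scale = layer.kv_scale.to("cpu").tolist()
        if not isinstance(kv_scale, float):
            raise ValueError("Only per-tensor scaling factors are supported")
        layer._kv_scale = kv_scale
    del layer.kv_scale


layer = ToyAttention(kv_cache_dtype="fp8_e4m3")
load_kv_scale_from_checkpoint(layer, {"attn.kv_scale": torch.tensor(0.023)})
print(layer._kv_scale)  # prints roughly 0.023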
@@ -8,8 +8,9 @@ from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+    QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.utils import print_warning_once

 ACTIVATION_SCHEMES = ["static", "dynamic"]

@@ -58,9 +59,13 @@ class Fp8Config(QuantizationConfig):
                    activation_scheme=activation_scheme)

     def get_quant_method(
-            self, layer: torch.nn.Module) -> Optional["Fp8LinearMethod"]:
+            self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]:
+        from vllm.attention.layer import Attention  # Avoid circular import
+
         if isinstance(layer, LinearBase):
             return Fp8LinearMethod(self)
+        if isinstance(layer, Attention):
+            return Fp8KVCacheMethod(self)
         return None

     def get_scaled_act_names(self) -> List[str]:
@@ -251,6 +256,44 @@ class Fp8LinearMethod(LinearMethodBase):
         return torch.narrow(output, 0, 0, x.shape[0])


+class Fp8KVCacheMethod(QuantizeMethodBase):
+    """Supports loading kv-cache scaling factors from FP8 checkpoints.
+    """
+
+    def __init__(self, quant_config: Fp8Config):
+        self.quant_config = quant_config
+
+    def create_weights(self, layer: torch.nn.Module):
+        """Create "weight" (aka kv_scale) for an attention layer.
+
+        Args:
+            layer: The layer that is using the QuantizeMethodBase factory.
+        """
+        # Initialize the KV cache scale to 1.0 as the default value.
+        # If the kv_scale appears in the checkpoint, it will be
+        # overwritten when loading weights.
+        layer.kv_scale = Parameter(torch.tensor(1.0), requires_grad=False)
+
+    def apply(self, layer: torch.nn.Module) -> torch.Tensor:
+        raise RuntimeError("Fp8KVCacheMethod.apply should not be called.")
+
+    def process_weights_after_loading(self, layer: Module) -> None:
+        # If the kv-cache dtype is auto, we enforce the kv-scale to be 1.0
+        # regardless whether the kv-scale is available in the checkpoint.
+        if layer.kv_cache_dtype != "auto":
+            kv_scale = layer.kv_scale.to("cpu").tolist()
+            if not isinstance(kv_scale, float):
+                raise ValueError("Only support per-tensor scaling factor "
+                                 "for fp8 KV cache")
+            layer._kv_scale = kv_scale
+            if layer._kv_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype:
+                print_warning_once(
+                    "Using KV cache scaling factor 1.0 for fp8_e4m3. This may "
+                    "cause accuracy issues. Please make sure kv-cache scaling "
+                    "factor is available in the fp8 checkpoint.")
+        del layer.kv_scale
+
+
 def all_close_1d(x: torch.Tensor) -> bool:
     assert len(x.shape) == 1
     return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
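On the user side, if I read the change right, an FP8 checkpoint that already carries .kv_scale tensors should be servable with an fp8 kv-cache without a separate scaling-factor file. A hedged sketch only: the model path is hypothetical, and the quantization/kv_cache_dtype keyword values are assumed from vLLM's engine arguments of this period rather than confirmed by this diff.

from vllm import LLM, SamplingParams

# Hypothetical FP8 checkpoint that includes per-layer .kv_scale parameters.
llm = LLM(model="my-org/Llama-3-8B-FP8",
          quantization="fp8",      # assumed engine arg for FP8 weights
          kv_cache_dtype="fp8")    # assumed to select the e4m3 kv-cache path
outputs = llm.generate(["Hello"], SamplingParams(max_tokens=8))
print(outputs[0].outputs[0].text)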