[BugFix] register quant scale tensors as buffer (#31395)
Signed-off-by: Boyuan Feng <boyuan@meta.com>
This commit is contained in:
@@ -28,6 +28,7 @@ from vllm.model_executor.layers.linear import (
|
||||
UnquantizedLinearMethod,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
@@ -46,6 +47,35 @@ from vllm.v1.kv_cache_interface import (
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def should_load_quant_weights(quant_method: QuantizeMethodBase | None) -> bool:
    """Returns whether the quantization method should load quantized weights."""
    # No quant method at all -> nothing quantized to load.
    if quant_method is None:
        return False
    # An unquantized linear method carries plain weights, not quant scales.
    return not isinstance(quant_method, UnquantizedLinearMethod)
|
||||
|
||||
|
||||
def set_default_quant_scales(layer: nn.Module, register_buffer: bool = False) -> None:
    """Sets default quantization scales for the layer."""
    scale_names = ("_k_scale", "_v_scale", "_q_scale", "_prob_scale")

    if register_buffer:
        # First-time setup: make the scales part of the module's state dict
        # so that model.to(device) moves them along with the weights.
        for name in scale_names:
            layer.register_buffer(name, torch.tensor(1.0, dtype=torch.float32))
    else:
        # Buffers already exist; reset them in place to the default of 1.0.
        for name in scale_names:
            getattr(layer, name).fill_(1.0)

    # We also keep q/k/v_scale on host (cpu) memory for attention
    # backends that require the scales to be on host instead of on device.
    # e.g. Flashinfer
    for name in scale_names:
        setattr(layer, name + "_float", 1.0)
|
||||
|
||||
|
||||
def _init_kv_cache_quant(
|
||||
layer: nn.Module,
|
||||
quant_config: QuantizationConfig | None,
|
||||
@@ -74,17 +104,21 @@ def _init_kv_cache_quant(
|
||||
# with the model weights.
|
||||
layer.kv_cache_dtype = kv_cache_dtype
|
||||
layer.calculate_kv_scales = calculate_kv_scales
|
||||
layer._k_scale = torch.tensor(1.0, dtype=torch.float32)
|
||||
layer._v_scale = torch.tensor(1.0, dtype=torch.float32)
|
||||
layer._q_scale = torch.tensor(1.0, dtype=torch.float32)
|
||||
layer._prob_scale = torch.tensor(1.0, dtype=torch.float32)
|
||||
|
||||
# We also keep q/k/v_scale on host (cpu) memory for attention
|
||||
# backends that require the scales to be on host instead of on device.
|
||||
# e.g. Flashinfer
|
||||
layer._q_scale_float = 1.0
|
||||
layer._k_scale_float = 1.0
|
||||
layer._v_scale_float = 1.0
|
||||
# Note [Register q/k/v/prob scales in state dict]
|
||||
# When calling model.to(device), only parameters/buffers in state dict are
|
||||
# moved. If not registering q/k/v/prob scales in state dict, there would
|
||||
# be an IMA error when a cuda kernel (e.g., quant_fp8) accesses the tensor
|
||||
# on cpu.
|
||||
# Registering in state dict means it interacts with weight loading. One edge
|
||||
# case is when quant_method is None, or quant_method is UnquantizedLinearMethod
|
||||
# (i.e., should_load_quant_weights(quant_method) == False).
|
||||
# In this case, the checkpoint does not have the scales. We need to
|
||||
# initialize the scales to 1.0 and update the scales after weight loading.
|
||||
# This is especially important when we load dummy weights first (providing
|
||||
# wrong scales) and then load real weights (which misses scales and keeps the
|
||||
# wrong scales from dummy load).
|
||||
set_default_quant_scales(layer, register_buffer=True)
|
||||
|
||||
# The output scale on host memory. This should be the input scale of
|
||||
# the quant op after this attention layer.
|
||||
@@ -93,9 +127,9 @@ def _init_kv_cache_quant(
|
||||
quant_method = (
|
||||
quant_config.get_quant_method(layer, prefix=prefix) if quant_config else None
|
||||
)
|
||||
if quant_method is not None and not isinstance(
|
||||
quant_method, UnquantizedLinearMethod
|
||||
):
|
||||
|
||||
# See Note [Register q/k/v/prob scales in state dict]
|
||||
if should_load_quant_weights(quant_method):
|
||||
assert isinstance(quant_method, BaseKVCacheMethod)
|
||||
# TODO (mgoin): kv cache dtype should be specified in the FP8
|
||||
# checkpoint config and become the "auto" behavior
|
||||
@@ -169,10 +203,16 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
assert num_heads % num_kv_heads == 0, (
|
||||
f"num_heads ({num_heads}) is not divisible by num_kv_heads ({num_kv_heads})"
|
||||
)
|
||||
self.quant_config = quant_config
|
||||
self.layer_name = prefix
|
||||
|
||||
# Initialize KV cache quantization attributes
|
||||
_init_kv_cache_quant(
|
||||
self, quant_config, prefix, kv_cache_dtype, calculate_kv_scales
|
||||
self,
|
||||
self.quant_config,
|
||||
self.layer_name,
|
||||
kv_cache_dtype,
|
||||
calculate_kv_scales,
|
||||
)
|
||||
|
||||
self.num_heads = num_heads
|
||||
@@ -249,7 +289,6 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||
compilation_config.static_forward_context[prefix] = self
|
||||
self.layer_name = prefix
|
||||
self.attn_type = attn_type
|
||||
|
||||
if kv_sharing_target_layer_name is not None:
|
||||
@@ -378,6 +417,17 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
||||
self.impl.process_weights_after_loading(act_dtype)
|
||||
|
||||
# If we should not load quant weights, we initialize the scales to 1.0
|
||||
# as the default value. See Note [Register q/k/v/prob scales in state dict]
|
||||
# for more details.
|
||||
quant_method = (
|
||||
self.quant_config.get_quant_method(self, prefix=self.layer_name)
|
||||
if self.quant_config
|
||||
else None
|
||||
)
|
||||
if not should_load_quant_weights(quant_method):
|
||||
set_default_quant_scales(self, register_buffer=False)
|
||||
|
||||
def get_attn_backend(self) -> type[AttentionBackend]:
|
||||
return self.attn_backend
|
||||
|
||||
@@ -453,10 +503,15 @@ class MLAAttention(nn.Module, AttentionLayerBase):
|
||||
kv_cache_dtype = "auto"
|
||||
block_size = 16
|
||||
calculate_kv_scales = False
|
||||
self.quant_config = quant_config
|
||||
|
||||
# Initialize KV cache quantization attributes
|
||||
_init_kv_cache_quant(
|
||||
self, quant_config, prefix, kv_cache_dtype, calculate_kv_scales
|
||||
self,
|
||||
self.quant_config,
|
||||
self.layer_name,
|
||||
kv_cache_dtype,
|
||||
calculate_kv_scales,
|
||||
)
|
||||
|
||||
dtype = torch.get_default_dtype()
|
||||
@@ -586,6 +641,17 @@ class MLAAttention(nn.Module, AttentionLayerBase):
|
||||
if hasattr(self.impl, "process_weights_after_loading"):
|
||||
self.impl.process_weights_after_loading(act_dtype)
|
||||
|
||||
# If we should not load quant weights, we initialize the scales to 1.0
|
||||
# as the default value. See Note [Register q/k/v/prob scales in state dict]
|
||||
# for more details.
|
||||
quant_method = (
|
||||
self.quant_config.get_quant_method(self, prefix=self.layer_name)
|
||||
if self.quant_config
|
||||
else None
|
||||
)
|
||||
if not should_load_quant_weights(quant_method):
|
||||
set_default_quant_scales(self, register_buffer=False)
|
||||
|
||||
def calc_kv_scales(
|
||||
self, q: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor
|
||||
) -> None:
|
||||
|
||||
Reference in New Issue
Block a user