diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 1d882eb87..75f5dfe2d 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -28,6 +28,7 @@ from vllm.model_executor.layers.linear import ( UnquantizedLinearMethod, ) from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape @@ -46,6 +47,35 @@ from vllm.v1.kv_cache_interface import ( logger = init_logger(__name__) +def should_load_quant_weights(quant_method: QuantizeMethodBase | None) -> bool: + """Returns whether the quantization method should load quantized weights.""" + return quant_method is not None and not isinstance( + quant_method, UnquantizedLinearMethod + ) + + +def set_default_quant_scales(layer: nn.Module, register_buffer: bool = False) -> None: + """Sets default quantization scales for the layer.""" + if register_buffer: + layer.register_buffer("_k_scale", torch.tensor(1.0, dtype=torch.float32)) + layer.register_buffer("_v_scale", torch.tensor(1.0, dtype=torch.float32)) + layer.register_buffer("_q_scale", torch.tensor(1.0, dtype=torch.float32)) + layer.register_buffer("_prob_scale", torch.tensor(1.0, dtype=torch.float32)) + else: + layer._k_scale.fill_(1.0) + layer._v_scale.fill_(1.0) + layer._q_scale.fill_(1.0) + layer._prob_scale.fill_(1.0) + + # We also keep q/k/v_scale on host (cpu) memory for attention + # backends that require the scales to be on host instead of on device. + # e.g. 
Flashinfer + layer._q_scale_float = 1.0 + layer._k_scale_float = 1.0 + layer._v_scale_float = 1.0 + layer._prob_scale_float = 1.0 + + def _init_kv_cache_quant( layer: nn.Module, quant_config: QuantizationConfig | None, @@ -74,17 +104,21 @@ def _init_kv_cache_quant( # with the model weights. layer.kv_cache_dtype = kv_cache_dtype layer.calculate_kv_scales = calculate_kv_scales - layer._k_scale = torch.tensor(1.0, dtype=torch.float32) - layer._v_scale = torch.tensor(1.0, dtype=torch.float32) - layer._q_scale = torch.tensor(1.0, dtype=torch.float32) - layer._prob_scale = torch.tensor(1.0, dtype=torch.float32) - # We also keep q/k/v_scale on host (cpu) memory for attention - # backends that require the scales to be on host instead of on device. - # e.g. Flashinfer - layer._q_scale_float = 1.0 - layer._k_scale_float = 1.0 - layer._v_scale_float = 1.0 + # Note [Register q/k/v/prob scales in state dict] + # When calling model.to(device), only parameters/buffers in state dict are + # moved. If not registering q/k/v/prob scales in state dict, there would + # be an IMA error when a cuda kernel (e.g., quant_fp8) accesses the tensor + # on cpu. + # Registering in state dict means it interacts with weight loading. One edge + # case is when quant_method is None, or quant_method is UnquantizedLinearMethod + # (i.e., should_load_quant_weights(quant_method) == False). + # In this case, the checkpoint does not have the scales. We need to + # initialize the scales to 1.0 and update the scales after weight loading. + # This is especially important when we load dummy weights first (providing + # wrong scales) and then load real weights (which misses scales and keeps the + # wrong scales from dummy load). + set_default_quant_scales(layer, register_buffer=True) # The output scale on host memory. This should be the input scale of # the quant op after this attention layer. 
@@ -93,9 +127,9 @@ def _init_kv_cache_quant( quant_method = ( quant_config.get_quant_method(layer, prefix=prefix) if quant_config else None ) - if quant_method is not None and not isinstance( - quant_method, UnquantizedLinearMethod - ): + + # See Note [Register q/k/v/prob scales in state dict] + if should_load_quant_weights(quant_method): assert isinstance(quant_method, BaseKVCacheMethod) # TODO (mgoin): kv cache dtype should be specified in the FP8 # checkpoint config and become the "auto" behavior @@ -169,10 +203,16 @@ class Attention(nn.Module, AttentionLayerBase): assert num_heads % num_kv_heads == 0, ( f"num_heads ({num_heads}) is not divisible by num_kv_heads ({num_kv_heads})" ) + self.quant_config = quant_config + self.layer_name = prefix # Initialize KV cache quantization attributes _init_kv_cache_quant( - self, quant_config, prefix, kv_cache_dtype, calculate_kv_scales + self, + self.quant_config, + self.layer_name, + kv_cache_dtype, + calculate_kv_scales, ) self.num_heads = num_heads @@ -249,7 +289,6 @@ class Attention(nn.Module, AttentionLayerBase): if prefix in compilation_config.static_forward_context: raise ValueError(f"Duplicate layer name: {prefix}") compilation_config.static_forward_context[prefix] = self - self.layer_name = prefix self.attn_type = attn_type if kv_sharing_target_layer_name is not None: @@ -378,6 +417,17 @@ class Attention(nn.Module, AttentionLayerBase): def process_weights_after_loading(self, act_dtype: torch.dtype): self.impl.process_weights_after_loading(act_dtype) + # If we should not load quant weights, we initialize the scales to 1.0 + # as the default value. See Note [Register q/k/v/prob scales in state dict] + # for more details. 
+ quant_method = ( + self.quant_config.get_quant_method(self, prefix=self.layer_name) + if self.quant_config + else None + ) + if not should_load_quant_weights(quant_method): + set_default_quant_scales(self, register_buffer=False) + def get_attn_backend(self) -> type[AttentionBackend]: return self.attn_backend @@ -453,10 +503,15 @@ class MLAAttention(nn.Module, AttentionLayerBase): kv_cache_dtype = "auto" block_size = 16 calculate_kv_scales = False + self.quant_config = quant_config # Initialize KV cache quantization attributes _init_kv_cache_quant( - self, quant_config, prefix, kv_cache_dtype, calculate_kv_scales + self, + self.quant_config, + self.layer_name, + kv_cache_dtype, + calculate_kv_scales, ) dtype = torch.get_default_dtype() @@ -586,6 +641,17 @@ class MLAAttention(nn.Module, AttentionLayerBase): if hasattr(self.impl, "process_weights_after_loading"): self.impl.process_weights_after_loading(act_dtype) + # If we should not load quant weights, we initialize the scales to 1.0 + # as the default value. See Note [Register q/k/v/prob scales in state dict] + # for more details. + quant_method = ( + self.quant_config.get_quant_method(self, prefix=self.layer_name) + if self.quant_config + else None + ) + if not should_load_quant_weights(quant_method): + set_default_quant_scales(self, register_buffer=False) + def calc_kv_scales( self, q: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor ) -> None: