[FP8][Kernel] Dynamic kv cache scaling factors computation (#11906)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Co-authored-by: Micah Williamson <micah.williamson@amd.com>
commit e97f802b2d
parent 6e650f56a1
@@ -3,6 +3,7 @@ import torch
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
+from vllm.platforms import current_platform

 logger = init_logger(__name__)

@@ -40,11 +41,16 @@ class BaseKVCacheMethod(QuantizeMethodBase):
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0
         # regardless whether the kv-scale is available in the checkpoint.
-        if layer.kv_cache_dtype != "auto":
+        # No need to process kv scales after loading if we are going to
+        # calculate them on the fly.
+        if layer.kv_cache_dtype != "auto" and not layer.calculate_kv_scales:
             if layer.k_scale > 0.0 and layer.v_scale > 0.0:
                 # We prefer to use separate k_scale and v_scale if present
                 k_scale = layer.k_scale.to("cpu").tolist()
                 v_scale = layer.v_scale.to("cpu").tolist()
+                if current_platform.is_rocm():
+                    k_scale *= 2
+                    v_scale *= 2
             elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
                 # If no scales were loaded (both scales are invalid negative
                 # values), use the default value of 1.0
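The new `layer.calculate_kv_scales` attribute short-circuits checkpoint scale processing when the scales are going to be computed at runtime instead. As a hedged usage sketch of the toggle this condition checks (the exact entrypoint keyword arguments are an assumption based on this PR's intent, not shown in the diff):

from vllm import LLM

# Hypothetical invocation: kv_cache_dtype="fp8" quantizes the KV cache,
# and calculate_kv_scales=True asks the engine to derive the k/v scaling
# factors on the fly rather than load them from the checkpoint.
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    kv_cache_dtype="fp8",
    calculate_kv_scales=True,
)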
@@ -58,6 +64,9 @@ class BaseKVCacheMethod(QuantizeMethodBase):
                 scale_to_duplicate = max(layer.k_scale, layer.v_scale)
                 k_scale = scale_to_duplicate.to("cpu").tolist()
                 v_scale = scale_to_duplicate.to("cpu").tolist()
+                if current_platform.is_rocm():
+                    k_scale *= 2
+                    v_scale *= 2

             if not isinstance(k_scale, float) or not isinstance(
                     v_scale, float):
@@ -65,9 +74,11 @@ class BaseKVCacheMethod(QuantizeMethodBase):
                     "for fp8 KV cache")

             # These are used in the final Attention.forward()
-            layer._k_scale = k_scale
-            layer._v_scale = v_scale
-            if (layer._k_scale == 1.0 and layer._v_scale == 1.0
+            layer._k_scale.copy_(k_scale)
+            layer._v_scale.copy_(v_scale)
+            layer._k_scale_float = k_scale
+            layer._v_scale_float = v_scale
+            if (k_scale == 1.0 and v_scale == 1.0
                     and "e5m2" not in layer.kv_cache_dtype):
                 logger.warning_once(
                     "Using KV cache scaling factor 1.0 for fp8_e4m3. This "
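Note that the scales are now written into pre-allocated tensors with `copy_` rather than rebound, so any kernel or compiled graph holding a reference to `layer._k_scale` observes the updated value; the `_float` twins keep a host-side copy. A minimal sketch of the on-the-fly computation this enables, assuming the amax-based convention cited elsewhere in this diff (function and constant names are illustrative, not from this commit):

import torch

FP8_E4M3_MAX = 448.0  # torch.finfo(torch.float8_e4m3fn).max

def update_kv_scales(layer: torch.nn.Module, k: torch.Tensor,
                     v: torch.Tensor) -> None:
    # quantized_value * scale ~= true_value, so scale = amax / fp8_max.
    k_scale = k.abs().amax().float() / FP8_E4M3_MAX
    v_scale = v.abs().amax().float() / FP8_E4M3_MAX
    # In-place copy keeps the tensor identities stable for captured kernels.
    layer._k_scale.copy_(k_scale)
    layer._v_scale.copy_(v_scale)
    layer._k_scale_float = float(k_scale)
    layer._v_scale_float = float(v_scale)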
@@ -6,8 +6,7 @@ import json
 import os
 import tempfile
 from collections import defaultdict
-from typing import (Any, Callable, Dict, Generator, Iterable, List, Optional,
-                    Tuple, Union)
+from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union

 import filelock
 import gguf
@@ -23,7 +22,6 @@ from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import (QuantizationConfig,
                                                      get_quantization_config)
-from vllm.model_executor.layers.quantization.schema import QuantParamSchema
 from vllm.platforms import current_platform
 from vllm.utils import PlaceholderModule

@@ -496,47 +494,6 @@ def gguf_quant_weights_iterator(
     yield name, param


-def kv_cache_scales_loader(
-        filename: str, tp_rank: int, tp_size: int, num_hidden_layers: int,
-        model_type: Optional[str]) -> Iterable[Tuple[int, float]]:
-    """
-    A simple utility to read in KV cache scaling factors that have been
-    previously serialized to disk. Used by the model to populate the appropriate
-    KV cache scaling factors. The serialization should represent a dictionary
-    whose keys are the TP ranks and values are another dictionary mapping layers
-    to their KV cache scaling factors.
-    Keep this function in sync with the output of
-    examples/other/fp8/extract_scales.py
-    """
-    try:
-        with open(filename) as f:
-            context = {
-                "model_type": model_type,
-                "num_hidden_layers": num_hidden_layers,
-                "tp_rank": tp_rank,
-                "tp_size": tp_size,
-            }
-            schema_dct = json.load(f)
-            schema = QuantParamSchema.model_validate(schema_dct,
-                                                     context=context)
-            layer_scales_map = schema.kv_cache.scaling_factor[tp_rank]
-            return layer_scales_map.items()
-
-    except FileNotFoundError:
-        logger.error("File or directory '%s' not found.", filename)
-    except json.JSONDecodeError:
-        logger.error("Error decoding JSON in file '%s'.", filename)
-    except Exception:
-        logger.exception("An error occurred while reading '%s'.", filename)
-    # This section is reached if and only if any of the excepts are hit
-    # Return an empty iterable (list) => no KV cache scales are loaded
-    # which ultimately defaults to 1.0 scales
-    logger.warning(
-        "Defaulting to KV cache scaling factors = 1.0 for all "
-        "layers in TP rank %d as an error occurred during loading.", tp_rank)
-    return []
-
-
 def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
     """convert PySafeSlice object from safetensors to torch.Tensor
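For reference, the removed loader consumed a JSON file whose layout the docstring above describes: TP rank mapped to a per-layer table of scaling factors. A minimal illustration of that structure as a Python literal (the exact top-level keys follow the QuantParamSchema usage shown and are a best-effort reconstruction; the values are made up):

# What kv_cache_scales_loader(filename, tp_rank=0, ...) would have parsed:
serialized_scales = {
    "model_type": "llama",
    "kv_cache": {
        "dtype": "float8_e4m3fn",
        "scaling_factor": {
            # TP rank 0: mapping from layer index to KV cache scale
            0: {0: 0.0213, 1: 0.0198, 2: 0.0241},
        },
    },
}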
@@ -30,8 +30,7 @@ from torch import nn
 from vllm.attention import Attention, AttentionMetadata
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
-                              get_tensor_model_parallel_world_size)
+from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@@ -44,9 +43,8 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
-    default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
+    default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.exaone import ExaoneConfig

@@ -576,32 +574,3 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
             weight_loader(param, loaded_weight)
             loaded_params.add(name)
         return loaded_params
-
-    # If this function is called, it should always initialize KV cache scale
-    # factors (or else raise an exception). Thus, handled exceptions should
-    # make sure to leave KV cache scale factors in a known good (dummy) state
-    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
-        tp_size = get_tensor_model_parallel_world_size()
-        tp_rank = get_tensor_model_parallel_rank()
-        for layer_idx, scaling_factor in kv_cache_scales_loader(
-                quantization_param_path,
-                tp_rank,
-                tp_size,
-                self.config.num_hidden_layers,
-                self.config.__class__.model_type,
-        ):
-            if not isinstance(self.transformer.h[layer_idx], nn.Identity):
-                layer_self_attn = self.transformer.h[layer_idx].attn
-
-                if current_platform.is_rocm():
-                    # The scaling factor convention we are assuming is
-                    # quantized_value * scaling_factor ~= true_value
-                    # which is consistent with the practice of setting
-                    # scaling_factor = tensor_amax / FPtype_max
-                    scaling_factor *= 2
-                if hasattr(layer_self_attn.attn, "_k_scale"):
-                    layer_self_attn.attn._k_scale = scaling_factor
-                    layer_self_attn.attn._v_scale = scaling_factor
-                else:
-                    raise RuntimeError("Self attention has no KV cache scaling "
-                                       "factor attribute!")
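The `scaling_factor *= 2` removed here (and re-added centrally in BaseKVCacheMethod above) reflects an fp8 format difference: checkpoint scales target OCP float8_e4m3fn (max 448), while ROCm historically used float8_e4m3fnuz, whose maximum is 240, roughly half the range. A quick check of that factor, assuming both dtypes are available in the local torch build:

import torch

fn_max = float(torch.finfo(torch.float8_e4m3fn).max)      # 448.0
fnuz_max = float(torch.finfo(torch.float8_e4m3fnuz).max)  # 240.0

# scale = amax / fp8_max, so halving the representable range roughly
# doubles the required scale; the code rounds 448/240 ~= 1.87 up to 2.
print(fn_max / fnuz_max)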
@@ -29,8 +29,7 @@ from transformers import GraniteConfig
 from vllm.attention import Attention, AttentionMetadata
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
-                              get_tensor_model_parallel_world_size)
+from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@@ -44,9 +43,8 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
-    default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
+    default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsLoRA, SupportsPP
@@ -518,29 +516,3 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
             weight_loader(param, loaded_weight)
             loaded_params.add(name)
         return loaded_params
-
-    # If this function is called, it should always initialize KV cache scale
-    # factors (or else raise an exception). Thus, handled exceptions should
-    # make sure to leave KV cache scale factors in a known good (dummy) state
-    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
-        tp_size = get_tensor_model_parallel_world_size()
-        tp_rank = get_tensor_model_parallel_rank()
-        for layer_idx, scaling_factor in kv_cache_scales_loader(
-                quantization_param_path, tp_rank, tp_size,
-                self.config.num_hidden_layers,
-                self.config.__class__.model_type):
-            if not isinstance(self.model.layers[layer_idx], nn.Identity):
-                layer_self_attn = self.model.layers[layer_idx].self_attn
-
-                if current_platform.is_rocm():
-                    # The scaling factor convention we are assuming is
-                    # quantized_value * scaling_factor ~= true_value
-                    # which is consistent with the practice of setting
-                    # scaling_factor = tensor_amax / FPtype_max
-                    scaling_factor *= 2
-                if hasattr(layer_self_attn.attn, "_k_scale"):
-                    layer_self_attn.attn._k_scale = scaling_factor
-                    layer_self_attn.attn._v_scale = scaling_factor
-                else:
-                    raise RuntimeError("Self attention has no KV cache scaling "
-                                       "factor attribute!")
@@ -29,8 +29,7 @@ from transformers import LlamaConfig
 from vllm.attention import Attention, AttentionMetadata
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
-                              get_tensor_model_parallel_world_size)
+from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@@ -43,9 +42,8 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
-    default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
+    default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsLoRA, SupportsPP
@@ -440,32 +438,6 @@ class LlamaModel(nn.Module):
             loaded_params.add(name)
         return loaded_params
-
-    # If this function is called, it should always initialize KV cache scale
-    # factors (or else raise an exception). Thus, handled exceptions should
-    # make sure to leave KV cache scale factors in a known good (dummy) state
-    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
-        tp_size = get_tensor_model_parallel_world_size()
-        tp_rank = get_tensor_model_parallel_rank()
-        for layer_idx, scaling_factor in kv_cache_scales_loader(
-                quantization_param_path, tp_rank, tp_size,
-                self.config.num_hidden_layers,
-                self.config.__class__.model_type):
-            if not isinstance(self.layers[layer_idx], nn.Identity):
-                layer_self_attn = self.layers[layer_idx].self_attn
-
-                if current_platform.is_rocm():
-                    # The scaling factor convention we are assuming is
-                    # quantized_value * scaling_factor ~= true_value
-                    # which is consistent with the practice of setting
-                    # scaling_factor = tensor_amax / FPtype_max
-                    scaling_factor *= 2
-                if hasattr(layer_self_attn.attn, "_k_scale"):
-                    layer_self_attn.attn._k_scale = scaling_factor
-                    layer_self_attn.attn._v_scale = scaling_factor
-                else:
-                    raise RuntimeError("Self attention has no KV cache scaling "
-                                       "factor attribute!")


 class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     packed_modules_mapping = {
@@ -593,9 +565,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
                 self.maybe_remap_mistral(name, loaded_weight)
                 for name, loaded_weight in weights)

-    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
-        self.model.load_kv_cache_scales(quantization_param_path)
-
     # This function is used to remap the mistral format as
     # used by Mistral and Llama <=2
     def maybe_remap_mistral(
@@ -831,6 +831,7 @@ class MllamaTextCrossAttention(nn.Module):
     ) -> torch.Tensor:
         # Skip writing kv-cache for the initial profiling run.
         if len(kv_cache.shape) > 1:
+            i = torch.ones(1, dtype=torch.float32)
             if self.attn.backend in (_Backend.FLASH_ATTN,
                                      _Backend.FLASH_ATTN_VLLM_V1):
                 cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
@@ -843,8 +844,8 @@ class MllamaTextCrossAttention(nn.Module):
                     attn_metadata.
                     cross_slot_mapping,  # type: ignore[union-attr]
                     "auto",
-                    1.0,
-                    1.0,
+                    i,
+                    i,
                 )
             elif self.attn.backend in (_Backend.XFORMERS, _Backend.TORCH_SDPA):
                 key_cache, value_cache = PagedAttention.split_kv_cache(
@@ -853,7 +854,7 @@ class MllamaTextCrossAttention(nn.Module):
                 cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
                 PagedAttention.write_to_paged_cache(
                     cached_k, cached_v, key_cache, value_cache,
                     attn_metadata.cross_slot_mapping, "auto", i, i)
             else:
                 raise ValueError(
                     f"Unsupported Attention backend {self.attn.backend} "
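Since the k/v scales now flow through as tensors, the cross-attention cache writes pass a 1-element ones tensor instead of the float literal 1.0. A hedged sketch of that pattern against a stand-in cache writer (only the argument order visible in this diff is assumed; the helper name is hypothetical):

import torch

def write_unscaled(write_fn, k, v, key_cache, value_cache, slot_mapping):
    # A ones tensor reproduces the old unscaled (1.0) behavior while
    # matching the new tensor-typed k_scale/v_scale parameters.
    unit = torch.ones(1, dtype=torch.float32)
    write_fn(k, v, key_cache, value_cache, slot_mapping, "auto", unit, unit)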
@@ -30,8 +30,7 @@ from transformers import PretrainedConfig
 from vllm.attention import Attention, AttentionMetadata
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
-                              get_tensor_model_parallel_world_size)
+from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@@ -44,9 +43,8 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
-    default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
+    default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsLoRA, SupportsPP
@@ -535,32 +533,3 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
             weight_loader(param, loaded_weight)
             loaded_params.add(name)
         return loaded_params
-
-    # If this function is called, it should always initialize KV cache scale
-    # factors (or else raise an exception). Thus, handled exceptions should
-    # make sure to leave KV cache scale factors in a known good (dummy) state
-    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
-        tp_size = get_tensor_model_parallel_world_size()
-        tp_rank = get_tensor_model_parallel_rank()
-        for layer_idx, scaling_factor in kv_cache_scales_loader(
-                quantization_param_path,
-                tp_rank,
-                tp_size,
-                self.config.num_hidden_layers,
-                self.config.__class__.model_type,
-        ):
-            if not isinstance(self.model.layers[layer_idx], nn.Identity):
-                layer_self_attn = self.model.layers[layer_idx].self_attn
-
-                if current_platform.is_rocm():
-                    # The scaling factor convention we are assuming is
-                    # quantized_value * scaling_factor ~= true_value
-                    # which is consistent with the practice of setting
-                    # scaling_factor = tensor_amax / FPtype_max
-                    scaling_factor *= 2
-                if hasattr(layer_self_attn.attn, "_k_scale"):
-                    layer_self_attn.attn._k_scale = scaling_factor
-                    layer_self_attn.attn._v_scale = scaling_factor
-                else:
-                    raise RuntimeError("Self attention has no KV cache scaling "
-                                       "factor attribute!")