Enable scaled FP8 (e4m3fn) KV cache on ROCm (AMD GPU) (#3290)

Co-authored-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Co-authored-by: HaiShaw <hixiao@gmail.com>
Co-authored-by: AdrianAbeyta <Adrian.Abeyta@amd.com>
Co-authored-by: Matthew Wong <Matthew.Wong2@amd.com>
Co-authored-by: root <root@gt-pla-u18-08.pla.dcgpu>
Co-authored-by: mawong-amd <156021403+mawong-amd@users.noreply.github.com>
Co-authored-by: ttbachyinsda <ttbachyinsda@outlook.com>
Co-authored-by: guofangze <guofangze@kuaishou.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: jacobthebanana <50071502+jacobthebanana@users.noreply.github.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Adrian Abeyta
2024-04-03 16:15:55 -05:00
committed by GitHub
parent 3dcb3e8b98
commit 2ff767b513
41 changed files with 2592 additions and 142 deletions

View File

@@ -41,11 +41,13 @@ from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
hf_model_weights_iterator,
kv_cache_scales_loader)
from vllm.sequence import SamplerOutput
from vllm.utils import is_hip
class LlamaMLP(nn.Module):
@@ -115,6 +117,15 @@ class LlamaAttention(nn.Module):
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
# This will be overwritten by model initialization if we are using it.
# N.B. currently we only support per tensor scalar scaling factors
# & only applicable to ROCm (AMD GPU).
# The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting
# scaling_factor = tensor_amax / FPtype_max
self.kv_scale = 1.0
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
@@ -153,7 +164,8 @@ class LlamaAttention(nn.Module):
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata,
self.kv_scale)
output, _ = self.o_proj(attn_output)
return output
@@ -402,3 +414,27 @@ class LlamaForCausalLM(nn.Module):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should
# make sure to leave KV cache scale factors in a known good (dummy) state
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
for layer_idx, scaling_factor in kv_cache_scales_loader(
quantization_param_path, tp_rank, tp_size,
self.config.num_hidden_layers,
self.config.__class__.model_type):
layer_self_attn = self.model.layers[layer_idx].self_attn
if is_hip():
# The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting
# scaling_factor = tensor_amax / FPtype_max
scaling_factor *= 2
if hasattr(layer_self_attn, "kv_scale"):
layer_self_attn.kv_scale = scaling_factor
else:
raise RuntimeError("Self attention has no KV cache scaling "
"factor attribute!")