Enable scaled FP8 (e4m3fn) KV cache on ROCm (AMD GPU) (#3290)
Co-authored-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com> Co-authored-by: HaiShaw <hixiao@gmail.com> Co-authored-by: AdrianAbeyta <Adrian.Abeyta@amd.com> Co-authored-by: Matthew Wong <Matthew.Wong2@amd.com> Co-authored-by: root <root@gt-pla-u18-08.pla.dcgpu> Co-authored-by: mawong-amd <156021403+mawong-amd@users.noreply.github.com> Co-authored-by: ttbachyinsda <ttbachyinsda@outlook.com> Co-authored-by: guofangze <guofangze@kuaishou.com> Co-authored-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: jacobthebanana <50071502+jacobthebanana@users.noreply.github.com> Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
@@ -41,11 +41,13 @@ from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_world_size)
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
hf_model_weights_iterator,
|
||||
kv_cache_scales_loader)
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.utils import is_hip
|
||||
|
||||
|
||||
class LlamaMLP(nn.Module):
|
||||
@@ -115,6 +117,15 @@ class LlamaAttention(nn.Module):
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
# This will be overwritten by model initialization if we are using it.
|
||||
# N.B. currently we only support per tensor scalar scaling factors
|
||||
# & only applicable to ROCm (AMD GPU).
|
||||
# The scaling factor convention we are assuming is
|
||||
# quantized_value * scaling_factor ~= true_value
|
||||
# which is consistent with the practice of setting
|
||||
# scaling_factor = tensor_amax / FPtype_max
|
||||
self.kv_scale = 1.0
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size,
|
||||
self.head_dim,
|
||||
@@ -153,7 +164,8 @@ class LlamaAttention(nn.Module):
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata,
|
||||
self.kv_scale)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@@ -402,3 +414,27 @@ class LlamaForCausalLM(nn.Module):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
# If this function is called, it should always initialize KV cache scale
|
||||
# factors (or else raise an exception). Thus, handled exceptions should
|
||||
# make sure to leave KV cache scale factors in a known good (dummy) state
|
||||
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
for layer_idx, scaling_factor in kv_cache_scales_loader(
|
||||
quantization_param_path, tp_rank, tp_size,
|
||||
self.config.num_hidden_layers,
|
||||
self.config.__class__.model_type):
|
||||
layer_self_attn = self.model.layers[layer_idx].self_attn
|
||||
|
||||
if is_hip():
|
||||
# The scaling factor convention we are assuming is
|
||||
# quantized_value * scaling_factor ~= true_value
|
||||
# which is consistent with the practice of setting
|
||||
# scaling_factor = tensor_amax / FPtype_max
|
||||
scaling_factor *= 2
|
||||
if hasattr(layer_self_attn, "kv_scale"):
|
||||
layer_self_attn.kv_scale = scaling_factor
|
||||
else:
|
||||
raise RuntimeError("Self attention has no KV cache scaling "
|
||||
"factor attribute!")
|
||||
|
||||
Reference in New Issue
Block a user