Support Longchat and RoPE scaling (#555)

Co-authored-by: Wing Lian <wing.lian@gmail.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Author: Lily Liu
Date: 2023-09-27 03:36:02 -07:00 (committed by GitHub)
Parent: cf5cb1e33e
Commit: 21877b0d75
4 changed files with 211 additions and 40 deletions

@@ -25,7 +25,7 @@
 The input of the model is flattened to a 1D tensor of tokens. The model uses
 InputMetadata to extract the original 2D shape of the input.
 """
-from typing import List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 from torch import nn
@@ -92,6 +92,7 @@ class LlamaAttention(nn.Module):
         num_heads: int,
         num_kv_heads: int,
         rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
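
For context, the rope_scaling dictionary threaded through this signature follows the Hugging Face LLaMA config convention for scaled rotary embeddings; the concrete values below are only illustrative, not taken from this commit:

# Illustrative Hugging Face-style rope_scaling entries:
rope_scaling = {"type": "linear", "factor": 8.0}   # position interpolation
# or, for NTK-aware dynamic scaling:
rope_scaling = {"type": "dynamic", "factor": 2.0}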
@@ -135,7 +136,8 @@ class LlamaAttention(nn.Module):
             base=self.rope_theta,
             max_position=self.max_position_embeddings,
             rotary_dim=self.head_dim,
-            num_kv_heads=self.num_kv_heads)
+            num_kv_heads=self.num_kv_heads,
+            rope_scaling=rope_scaling)
 
     def forward(
         self,
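
As a rough sketch of what linear RoPE scaling (position interpolation) does once it reaches the rotary embedding, positions are divided by the scaling factor before the rotary frequencies are computed, so a longer context is compressed back into the range the model was trained on. This is a minimal standalone illustration, not the vLLM kernel, and the helper name is made up:

import torch

def scaled_rope_cache(head_dim: int, max_position: int,
                      base: float = 10000.0, factor: float = 8.0):
    # Standard rotary inverse frequencies.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    # Linear scaling: compress positions by `factor` (position interpolation).
    positions = torch.arange(max_position).float() / factor
    freqs = torch.einsum("i,j->ij", positions, inv_freq)
    return freqs.cos(), freqs.sin()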
@@ -165,6 +167,7 @@ class LlamaDecoderLayer(nn.Module):
         self.hidden_size = config.hidden_size
         # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings",
                                           8192)
         self.self_attn = LlamaAttention(
@@ -172,6 +175,7 @@ class LlamaDecoderLayer(nn.Module):
             num_heads=config.num_attention_heads,
             num_kv_heads=config.num_key_value_heads,
             rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
         )
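
With the config plumbing above, a checkpoint whose config.json carries a rope_scaling entry (LongChat-style models) should load without extra flags. A hedged usage sketch; the model name is only an example:

from vllm import LLM, SamplingParams

llm = LLM(model="lmsys/longchat-7b-16k")  # config.json includes rope_scaling
outputs = llm.generate(["Summarize the following long transcript: ..."],
                       SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)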