[Bugfix / Core] Prefix Caching Guards (merged with main) (#4846)

Co-authored-by: rsnm2 <rshaw@neuralmagic.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
Authored by Zhuohan Li on 2024-05-27 15:18:17 -07:00; committed by GitHub
parent f17a1a8f96
commit 1102bef219
11 changed files with 167 additions and 44 deletions


@@ -86,10 +86,8 @@ class Qwen2Attention(nn.Module):
                 num_kv_heads: int,
                 max_position: int = 4096 * 32,
                 rope_theta: float = 10000,
-                use_sliding_window: bool = False,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
-                sliding_window: Optional[int] = None,
                 rope_scaling: Optional[Tuple] = None) -> None:
        super().__init__()
        self.hidden_size = hidden_size
@@ -112,7 +110,6 @@ class Qwen2Attention(nn.Module):
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
-       self.sliding_window = sliding_window if use_sliding_window else None

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
@@ -140,7 +137,6 @@ class Qwen2Attention(nn.Module):
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
-           sliding_window=self.sliding_window,
            cache_config=cache_config,
            quant_config=quant_config)
@@ -164,7 +160,6 @@ class Qwen2DecoderLayer(nn.Module):
    def __init__(
        self,
        config: Qwen2Config,
-       layer_idx: int,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
@@ -173,18 +168,14 @@ class Qwen2DecoderLayer(nn.Module):
        # Requires transformers > 4.32.0
        rope_theta = getattr(config, "rope_theta", 1000000)
        rope_scaling = getattr(config, "rope_scaling", None)
-       use_sliding_window = (config.use_sliding_window
-                             and layer_idx < config.max_window_layers)
        self.self_attn = Qwen2Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
-           use_sliding_window=use_sliding_window,
            cache_config=cache_config,
            quant_config=quant_config,
-           sliding_window=config.sliding_window,
            rope_scaling=rope_scaling)
        self.mlp = Qwen2MLP(
            hidden_size=self.hidden_size,
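
For context, a minimal sketch of the per-layer decision removed above (not part of this diff; the helper name is hypothetical and the config object is a stand-in for the HF Qwen2 config): previously only the first `max_window_layers` decoder layers applied sliding-window attention.

from types import SimpleNamespace

# Hypothetical helper mirroring the per-layer decision deleted above:
# only layers with index < max_window_layers used sliding-window attention.
def layer_uses_sliding_window(config, layer_idx: int) -> bool:
    return bool(config.use_sliding_window
                and layer_idx < config.max_window_layers)

cfg = SimpleNamespace(use_sliding_window=True, max_window_layers=28,
                      num_hidden_layers=32)
# With these illustrative values, layers 0-27 would use the window, 28-31 would not.
print([layer_uses_sliding_window(cfg, i) for i in range(cfg.num_hidden_layers)])

With this per-layer switch gone, the model either applies the sliding window to every layer or rejects the checkpoint, which is what the guard added to Qwen2ForCausalLM below enforces.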
@@ -244,8 +235,8 @@ class Qwen2Model(nn.Module):
            config.hidden_size,
        )
        self.layers = nn.ModuleList([
-           Qwen2DecoderLayer(config, layer_idx, cache_config, quant_config)
-           for layer_idx in range(config.num_hidden_layers)
+           Qwen2DecoderLayer(config, cache_config, quant_config)
+           for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -302,6 +293,18 @@ class Qwen2ForCausalLM(nn.Module):
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        del lora_config
+       # TODO (@robertgshaw2): see if this can be moved out
+       if (cache_config.sliding_window is not None
+               and hasattr(config, "max_window_layers")):
+           raise ValueError("Sliding window for some but not all layers is "
+                            "not supported. This model uses sliding window "
+                            "but `max_window_layers` = %s is less than "
+                            "`num_hidden_layers` = %s. Please open an issue "
+                            "to discuss this feature." % (
+                                config.max_window_layers,
+                                config.num_hidden_layers,
+                            ))
        super().__init__()
        self.config = config
        self.quant_config = quant_config
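
As a minimal, standalone sketch of the new guard's behavior (using SimpleNamespace stand-ins rather than vLLM's real Qwen2Config/CacheConfig objects), a checkpoint that enables sliding window while also defining `max_window_layers` now fails fast at model construction instead of silently applying the window to every layer:

from types import SimpleNamespace

def check_sliding_window(config, cache_config) -> None:
    # Same condition as the guard added to Qwen2ForCausalLM.__init__ above.
    if (cache_config.sliding_window is not None
            and hasattr(config, "max_window_layers")):
        raise ValueError(
            "Sliding window for some but not all layers is not supported "
            f"(`max_window_layers` = {config.max_window_layers}, "
            f"`num_hidden_layers` = {config.num_hidden_layers}).")

# Partially windowed config: rejected up front.
try:
    check_sliding_window(
        SimpleNamespace(max_window_layers=28, num_hidden_layers=32),
        SimpleNamespace(sliding_window=4096))
except ValueError as exc:
    print(exc)

# No sliding window in the cache config: passes the guard.
check_sliding_window(SimpleNamespace(num_hidden_layers=32),
                     SimpleNamespace(sliding_window=None))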