[ Misc ] non-uniform quantization via compressed-tensors for Llama (#6515)
This commit is contained in:
@@ -51,6 +51,7 @@ class GPT2Attention(nn.Module):
|
||||
config: GPT2Config,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@@ -68,12 +69,14 @@ class GPT2Attention(nn.Module):
|
||||
total_num_heads,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.c_attn",
|
||||
)
|
||||
self.c_proj = RowParallelLinear(
|
||||
self.hidden_size,
|
||||
self.hidden_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.c_proj",
|
||||
)
|
||||
self.attn = Attention(self.num_heads,
|
||||
self.head_dim,
|
||||
@@ -101,6 +104,7 @@ class GPT2MLP(nn.Module):
|
||||
intermediate_size: int,
|
||||
config: GPT2Config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
@@ -109,12 +113,14 @@ class GPT2MLP(nn.Module):
|
||||
intermediate_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.c_fc",
|
||||
)
|
||||
self.c_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.c_proj",
|
||||
)
|
||||
self.act = get_act_fn(config.activation_function, quant_config,
|
||||
intermediate_size)
|
||||
@@ -133,6 +139,7 @@ class GPT2Block(nn.Module):
|
||||
config: GPT2Config,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
@@ -140,9 +147,15 @@ class GPT2Block(nn.Module):
|
||||
hidden_size)
|
||||
|
||||
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.attn = GPT2Attention(config, cache_config, quant_config)
|
||||
self.attn = GPT2Attention(config,
|
||||
cache_config,
|
||||
quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.mlp = GPT2MLP(inner_dim, config, quant_config)
|
||||
self.mlp = GPT2MLP(inner_dim,
|
||||
config,
|
||||
quant_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -175,6 +188,7 @@ class GPT2Model(nn.Module):
|
||||
config: GPT2Config,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -186,7 +200,9 @@ class GPT2Model(nn.Module):
|
||||
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
|
||||
self.start_layer, self.end_layer, self.h = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda: GPT2Block(config, cache_config, quant_config))
|
||||
lambda prefix: GPT2Block(
|
||||
config, cache_config, quant_config, prefix=prefix),
|
||||
prefix=f"{prefix}.h")
|
||||
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
||||
|
||||
def forward(
|
||||
@@ -229,7 +245,10 @@ class GPT2LMHeadModel(nn.Module):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.transformer = GPT2Model(config, cache_config, quant_config)
|
||||
self.transformer = GPT2Model(config,
|
||||
cache_config,
|
||||
quant_config,
|
||||
prefix="transformer")
|
||||
self.lm_head = self.transformer.wte
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
|
||||
@@ -62,17 +62,20 @@ class LlamaMLP(nn.Module):
|
||||
hidden_act: str,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
bias: bool = False,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
input_size=hidden_size,
|
||||
output_sizes=[intermediate_size] * 2,
|
||||
bias=bias,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.gate_up_proj")
|
||||
self.down_proj = RowParallelLinear(input_size=intermediate_size,
|
||||
output_size=hidden_size,
|
||||
bias=bias,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.down_proj")
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
@@ -99,6 +102,7 @@ class LlamaAttention(nn.Module):
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
bias: bool = False,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@@ -132,12 +136,14 @@ class LlamaAttention(nn.Module):
|
||||
total_num_kv_heads=self.total_num_kv_heads,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
input_size=self.total_num_heads * self.head_dim,
|
||||
output_size=hidden_size,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
)
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
@@ -176,6 +182,7 @@ class LlamaDecoderLayer(nn.Module):
|
||||
config: LlamaConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@@ -203,6 +210,7 @@ class LlamaDecoderLayer(nn.Module):
|
||||
quant_config=quant_config,
|
||||
bias=attention_bias,
|
||||
cache_config=cache_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
)
|
||||
self.mlp = LlamaMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
@@ -210,6 +218,7 @@ class LlamaDecoderLayer(nn.Module):
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
bias=getattr(config, "mlp_bias", False),
|
||||
prefix=f"{prefix}.mlp",
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
@@ -253,6 +262,7 @@ class LlamaModel(nn.Module):
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -272,9 +282,11 @@ class LlamaModel(nn.Module):
|
||||
self.embed_tokens = PPMissingLayer()
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda: LlamaDecoderLayer(config=config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config))
|
||||
lambda prefix: LlamaDecoderLayer(config=config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix),
|
||||
prefix=f"{prefix}.layers")
|
||||
if get_pp_group().is_last_rank:
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
else:
|
||||
@@ -370,7 +382,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
|
||||
self.model = LlamaModel(config,
|
||||
cache_config,
|
||||
quant_config,
|
||||
lora_config=lora_config)
|
||||
lora_config=lora_config,
|
||||
prefix="model")
|
||||
if get_pp_group().is_last_rank:
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
if lora_config:
|
||||
|
||||
@@ -67,7 +67,8 @@ class MixtralMoE(nn.Module):
|
||||
intermediate_size: int,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
tp_size: Optional[int] = None):
|
||||
tp_size: Optional[int] = None,
|
||||
prefix: str = ""):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
@@ -76,7 +77,8 @@ class MixtralMoE(nn.Module):
|
||||
num_experts,
|
||||
bias=False,
|
||||
params_dtype=params_dtype,
|
||||
quant_config=None)
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.gate")
|
||||
|
||||
self.experts = FusedMoE(num_experts=num_experts,
|
||||
top_k=top_k,
|
||||
@@ -86,7 +88,8 @@ class MixtralMoE(nn.Module):
|
||||
reduce_results=True,
|
||||
renormalize=True,
|
||||
quant_config=quant_config,
|
||||
tp_size=tp_size)
|
||||
tp_size=tp_size,
|
||||
prefix=f"{prefix}.experts")
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
# NOTE: hidden_states can have either 1D or 2D shape.
|
||||
@@ -109,6 +112,7 @@ class MixtralAttention(nn.Module):
|
||||
rope_theta: float = 10000,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@@ -139,12 +143,14 @@ class MixtralAttention(nn.Module):
|
||||
self.total_num_kv_heads,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
)
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
@@ -182,6 +188,7 @@ class MixtralDecoderLayer(nn.Module):
|
||||
config: MixtralConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@@ -194,13 +201,15 @@ class MixtralDecoderLayer(nn.Module):
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
rope_theta=rope_theta,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn")
|
||||
self.block_sparse_moe = MixtralMoE(
|
||||
num_experts=config.num_local_experts,
|
||||
top_k=config.num_experts_per_tok,
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.block_sparse_moe")
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
@@ -243,6 +252,7 @@ class MixtralModel(nn.Module):
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.padding_idx = config.pad_token_id
|
||||
@@ -258,8 +268,11 @@ class MixtralModel(nn.Module):
|
||||
)
|
||||
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers, lambda: MixtralDecoderLayer(
|
||||
config, cache_config, quant_config=quant_config))
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: MixtralDecoderLayer(
|
||||
config, cache_config, quant_config=quant_config, prefix=prefix
|
||||
),
|
||||
prefix=f"{prefix}.layers")
|
||||
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
@@ -331,7 +344,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
|
||||
self.model = MixtralModel(config,
|
||||
cache_config,
|
||||
quant_config,
|
||||
lora_config=lora_config)
|
||||
lora_config=lora_config,
|
||||
prefix="model")
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
if lora_config:
|
||||
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Callable, Dict, List, Tuple
|
||||
from typing import Dict, List, Protocol, Tuple
|
||||
|
||||
import torch
|
||||
from torch.func import functional_call
|
||||
@@ -45,6 +45,15 @@ def merge_vision_embeddings(input_ids: torch.Tensor,
|
||||
return inputs_embeds
|
||||
|
||||
|
||||
class LayerFn(Protocol):
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
prefix="",
|
||||
) -> torch.nn.Module:
|
||||
...
|
||||
|
||||
|
||||
class PPMissingLayer(torch.nn.Identity):
|
||||
"""
|
||||
A placeholder layer for missing layers in a pipeline parallel model.
|
||||
@@ -119,7 +128,9 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
|
||||
|
||||
|
||||
def make_layers(
|
||||
num_hidden_layers: int, layer_fn: Callable[[], torch.nn.Module]
|
||||
num_hidden_layers: int,
|
||||
layer_fn: LayerFn,
|
||||
prefix: str,
|
||||
) -> Tuple[int, int, torch.nn.ModuleList]:
|
||||
"""Make a list of layers with the given layer function, taking
|
||||
pipeline parallelism into account.
|
||||
@@ -131,8 +142,8 @@ def make_layers(
|
||||
get_pp_group().world_size)
|
||||
modules = torch.nn.ModuleList(
|
||||
[PPMissingLayer() for _ in range(start_layer)] + [
|
||||
maybe_offload_to_cpu(layer_fn())
|
||||
for _ in range(start_layer, end_layer)
|
||||
maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
|
||||
for idx in range(start_layer, end_layer)
|
||||
] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
|
||||
return start_layer, end_layer, modules
|
||||
|
||||
|
||||
Reference in New Issue
Block a user