[Model] Correct Mixtral FP8 checkpoint loading (#5231)

This commit is contained in:
Cody Yu
2024-06-05 10:58:50 -07:00
committed by GitHub
parent ccd4f129e8
commit 5563a4dea8
2 changed files with 80 additions and 35 deletions

View File

@@ -300,14 +300,15 @@ def all_close_1d(x: torch.Tensor) -> bool:
def per_tensor_quantize(tensor: torch.Tensor, def per_tensor_quantize(tensor: torch.Tensor,
inv_scale: float) -> torch.Tensor: inv_scale: Union[float, torch.Tensor]) -> torch.Tensor:
finfo = torch.finfo(torch.float8_e4m3fn) finfo = torch.finfo(torch.float8_e4m3fn)
qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max) qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
return qweight.to(torch.float8_e4m3fn) return qweight.to(torch.float8_e4m3fn)
def per_tensor_dequantize(tensor: torch.Tensor, def per_tensor_dequantize(
inv_scale: float) -> torch.Tensor: tensor: torch.Tensor, inv_scale: Union[float,
torch.Tensor]) -> torch.Tensor:
fake_qweight = tensor.to(torch.float16) fake_qweight = tensor.to(torch.float16)
dq_weight = fake_qweight * inv_scale dq_weight = fake_qweight * inv_scale
return dq_weight return dq_weight

View File

@@ -41,7 +41,9 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.quantization.fp8 import Fp8Config from vllm.model_executor.layers.quantization.fp8 import (Fp8Config,
per_tensor_dequantize,
per_tensor_quantize)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -98,16 +100,16 @@ class MixtralMoE(nn.Module):
if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized: if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
params_dtype = torch.float8_e4m3fn params_dtype = torch.float8_e4m3fn
self.w13_weight = nn.Parameter( self.w13_weight = nn.Parameter(torch.empty(self.num_total_experts,
torch.empty(self.num_total_experts, 2 * self.intermediate_size,
2 * self.intermediate_size, self.hidden_size,
self.hidden_size, dtype=params_dtype),
dtype=params_dtype)) requires_grad=False)
self.w2_weight = nn.Parameter( self.w2_weight = nn.Parameter(torch.empty(self.num_total_experts,
torch.empty(self.num_total_experts, self.hidden_size,
self.hidden_size, self.intermediate_size,
self.intermediate_size, dtype=params_dtype),
dtype=params_dtype)) requires_grad=False)
set_weight_attrs(self.w13_weight, { set_weight_attrs(self.w13_weight, {
"weight_loader": self.weight_loader, "weight_loader": self.weight_loader,
@@ -124,7 +126,10 @@ class MixtralMoE(nn.Module):
if self.use_fp8: if self.use_fp8:
# WEIGHT_SCALE (for fp8) # WEIGHT_SCALE (for fp8)
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts, self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts,
2,
dtype=torch.float32), dtype=torch.float32),
requires_grad=False) requires_grad=False)
self.w2_scale = nn.Parameter(torch.ones(self.num_total_experts, self.w2_scale = nn.Parameter(torch.ones(self.num_total_experts,
@@ -148,11 +153,11 @@ class MixtralMoE(nn.Module):
raise ValueError( raise ValueError(
"Found static activation scheme for checkpoint that " "Found static activation scheme for checkpoint that "
"was not serialized fp8.") "was not serialized fp8.")
self.a13_scale = nn.Parameter(torch.zeros( self.a13_scale = nn.Parameter(torch.ones(
self.num_total_experts, dtype=torch.float32), self.num_total_experts, dtype=torch.float32),
requires_grad=False) requires_grad=False)
self.a2_scale = nn.Parameter(torch.zeros( self.a2_scale = nn.Parameter(torch.ones(self.num_total_experts,
self.num_total_experts, dtype=torch.float32), dtype=torch.float32),
requires_grad=False) requires_grad=False)
set_weight_attrs(self.a13_scale, { set_weight_attrs(self.a13_scale, {
@@ -175,8 +180,22 @@ class MixtralMoE(nn.Module):
shard_size:2 * shard_size, :] = loaded_weight[shard, :] shard_size:2 * shard_size, :] = loaded_weight[shard, :]
if weight_name.endswith("w2.weight"): if weight_name.endswith("w2.weight"):
param_data[expert_id, :, :] = loaded_weight[:, shard] param_data[expert_id, :, :] = loaded_weight[:, shard]
if "act_scale" in weight_name or "weight_scale" in weight_name:
# Loading scales
if "act_scale" in weight_name or "w2.weight_scale" in weight_name:
if param_data[expert_id] != 1 and (param_data[expert_id] -
loaded_weight).abs() > 1e-5:
raise ValueError(
"act_scales of w1 and w3 of a layer "
f"must be equal. But got {param_data[expert_id]} "
f"vs. {loaded_weight}")
param_data[expert_id] = loaded_weight param_data[expert_id] = loaded_weight
elif "weight_scale" in weight_name:
# We have to keep the weight scales of w1 and w3 because
# we need to re-quantize w1/w3 weights after weight loading.
assert "w1" in weight_name or "w3" in weight_name
shard_id = 0 if "w1" in weight_name else 1
param_data[expert_id][shard_id] = loaded_weight
def process_weights_after_loading(self): def process_weights_after_loading(self):
# Fp8 is the only case where we need to process after loading. # Fp8 is the only case where we need to process after loading.
@@ -189,6 +208,12 @@ class MixtralMoE(nn.Module):
dtype=torch.float8_e4m3fn) dtype=torch.float8_e4m3fn)
w2_weight = torch.empty_like(self.w2_weight.data, w2_weight = torch.empty_like(self.w2_weight.data,
dtype=torch.float8_e4m3fn) dtype=torch.float8_e4m3fn)
# Re-initialize w13_scale because we directly quantize
# merged w13 weights and generate a single scaling factor.
self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts,
dtype=torch.float32),
requires_grad=False)
for expert in range(self.num_total_experts): for expert in range(self.num_total_experts):
w13_weight[expert, :, :], self.w13_scale[ w13_weight[expert, :, :], self.w13_scale[
expert] = ops.scaled_fp8_quant( expert] = ops.scaled_fp8_quant(
@@ -199,25 +224,44 @@ class MixtralMoE(nn.Module):
self.w13_weight = nn.Parameter(w13_weight, requires_grad=False) self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
self.w2_weight = nn.Parameter(w2_weight, requires_grad=False) self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
# If checkpoint is fp8 + static, cleanup act_scales. else:
# Since state_dict has an act_scale per expert but our kernels # If checkpoint is fp8 + static, cleanup act_scales.
# are passed one act_scale shared across all experts. # Since state_dict has an act_scale per expert but our kernels
elif self.quant_config.activation_scheme == "static": # are passed one act_scale shared across all experts.
if self.a13_scale is None or self.a2_scale is None: if self.quant_config.activation_scheme == "static":
raise ValueError( if self.a13_scale is None or self.a2_scale is None:
"QuantConfig has static quantization, but found " raise ValueError(
"activation scales are None.") "QuantConfig has static quantization, but found "
"activation scales are None.")
if (not all_close_1d(self.a13_scale) if (not all_close_1d(self.a13_scale)
or not all_close_1d(self.a2_scale)): or not all_close_1d(self.a2_scale)):
print_warning_once( print_warning_once(
"Found act_scales that are not equal for fp8 MoE layer. " "Found act_scales that are not equal for "
"Using the maximum across experts for each layer. ") "fp8 MoE layer. Using the maximum across experts "
"for each layer. ")
self.a13_scale = nn.Parameter(self.a13_scale.max(), self.a13_scale = nn.Parameter(self.a13_scale.max(),
requires_grad=False) requires_grad=False)
self.a2_scale = nn.Parameter(self.a2_scale.max(), self.a2_scale = nn.Parameter(self.a2_scale.max(),
requires_grad=False) requires_grad=False)
assert self.w13_scale is not None
shard_size = self.intermediate_size
max_w13_scales = self.w13_scale.max(dim=1).values
for expert_id in range(self.num_total_experts):
start = 0
for shard_id in range(2):
dq_weight = per_tensor_dequantize(
self.w13_weight[expert_id][start:start +
shard_size, :],
self.w13_scale[expert_id][shard_id])
self.w13_weight[expert_id][
start:start + shard_size, :] = per_tensor_quantize(
dq_weight, max_w13_scales[expert_id])
start += shard_size
self.w13_scale = nn.Parameter(max_w13_scales, requires_grad=False)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_size = hidden_states.shape num_tokens, hidden_size = hidden_states.shape