[Misc] Update Fused MoE weight loading (#7334)

This commit is contained in:
Dipika Sikka
2024-08-13 14:57:45 -04:00
committed by GitHub
parent fb377d7e74
commit d3bdfd3ab9
6 changed files with 269 additions and 206 deletions

View File

@@ -290,23 +290,29 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
w13_scale = torch.nn.Parameter(torch.ones(num_experts,
2,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_scale", w13_scale)
w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
2,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
w2_scale = torch.nn.Parameter(torch.ones(num_experts,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_scale", w2_scale)
w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# If loading fp8 checkpoint, pass the weight loaders.
# If loading an fp16 checkpoint, do not (we will quantize in
# process_weights_after_loading()
if self.quant_config.is_checkpoint_fp8_serialized:
set_weight_attrs(w13_scale, extra_weight_attrs)
set_weight_attrs(w2_scale, extra_weight_attrs)
set_weight_attrs(w13_weight_scale, {
"is_fp8_scale": True,
**extra_weight_attrs
})
set_weight_attrs(w2_weight_scale, {
"is_fp8_scale": True,
**extra_weight_attrs
})
# INPUT_SCALES
if self.quant_config.activation_scheme == "static":
@@ -315,20 +321,26 @@ class Fp8MoEMethod(FusedMoEMethodBase):
"Found static activation scheme for checkpoint that "
"was not serialized fp8.")
a13_scale = torch.nn.Parameter(torch.ones(num_experts,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("a13_scale", a13_scale)
set_weight_attrs(a13_scale, extra_weight_attrs)
w13_input_scale = torch.nn.Parameter(torch.ones(
num_experts, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_input_scale", w13_input_scale)
set_weight_attrs(w13_input_scale, {
"is_fp8_scale": True,
**extra_weight_attrs
})
a2_scale = torch.nn.Parameter(torch.ones(num_experts,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("a2_scale", a2_scale)
set_weight_attrs(a2_scale, extra_weight_attrs)
w2_input_scale = torch.nn.Parameter(torch.ones(
num_experts, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_input_scale", w2_input_scale)
set_weight_attrs(w2_input_scale, {
"is_fp8_scale": True,
**extra_weight_attrs
})
else:
layer.a13_scale = None
layer.a2_scale = None
layer.w13_input_scale = None
layer.w2_input_scale = None
def process_weights_after_loading(self, layer: Module) -> None:
@@ -341,16 +353,16 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# Re-initialize w13_scale because we directly quantize
# merged w13 weights and generate a single scaling factor.
layer.w13_scale = torch.nn.Parameter(torch.ones(
layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
layer.num_experts,
dtype=torch.float32,
device=w13_weight.device),
requires_grad=False)
requires_grad=False)
for expert in range(layer.num_experts):
w13_weight[expert, :, :], layer.w13_scale[
w13_weight[expert, :, :], layer.w13_weight_scale[
expert] = ops.scaled_fp8_quant(
layer.w13_weight.data[expert, :, :])
w2_weight[expert, :, :], layer.w2_scale[
w2_weight[expert, :, :], layer.w2_weight_scale[
expert] = ops.scaled_fp8_quant(
layer.w2_weight.data[expert, :, :])
layer.w13_weight = torch.nn.Parameter(w13_weight,
@@ -366,40 +378,41 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# Fp8 moe kernels require a single activation scale.
# We take the max of all the scales in case they differ.
if self.quant_config.activation_scheme == "static":
if layer.a13_scale is None or layer.a2_scale is None:
if (layer.w13_input_scale is None
or layer.w2_input_scale is None):
raise ValueError(
"QuantConfig has static quantization, but found "
"activation scales are None.")
if (not all_close_1d(layer.a13_scale)
or not all_close_1d(layer.a2_scale)):
if (not all_close_1d(layer.w13_input_scale)
or not all_close_1d(layer.w2_input_scale)):
print_warning_once(
"Found input_scales that are not equal for "
"fp8 MoE layer. Using the maximum across experts "
"for each layer. ")
layer.a13_scale = torch.nn.Parameter(layer.a13_scale.max(),
requires_grad=False)
layer.a2_scale = torch.nn.Parameter(layer.a2_scale.max(),
requires_grad=False)
layer.w13_input_scale = torch.nn.Parameter(
layer.w13_input_scale.max(), requires_grad=False)
layer.w2_input_scale = torch.nn.Parameter(
layer.w2_input_scale.max(), requires_grad=False)
# Fp8 moe kernel needs single weight scale for w13 per expert.
# We take the max then dequant and requant each expert.
assert layer.w13_scale is not None
assert layer.w13_weight_scale is not None
shard_size = layer.intermediate_size_per_partition
max_w13_scales = layer.w13_scale.max(dim=1).values
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
for expert_id in range(layer.num_experts):
start = 0
for shard_id in range(2):
dq_weight = per_tensor_dequantize(
layer.w13_weight[expert_id][start:start +
shard_size, :],
layer.w13_scale[expert_id][shard_id])
layer.w13_weight_scale[expert_id][shard_id])
layer.w13_weight[expert_id][
start:start + shard_size, :], _ = ops.scaled_fp8_quant(
dq_weight, max_w13_scales[expert_id])
start += shard_size
layer.w13_scale = torch.nn.Parameter(max_w13_scales,
requires_grad=False)
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
requires_grad=False)
return
def apply(self,
@@ -407,27 +420,33 @@ class Fp8MoEMethod(FusedMoEMethodBase):
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None) -> torch.Tensor:
renormalize: bool,
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_moe
return fused_moe(x,
layer.w13_weight,
layer.w2_weight,
router_logits,
top_k,
renormalize=renormalize,
inplace=True,
use_fp8=True,
w1_scale=layer.w13_scale,
w2_scale=layer.w2_scale,
a1_scale=layer.a13_scale,
a2_scale=layer.a2_scale,
use_grouped_topk=use_grouped_topk,
num_expert_group=num_expert_group,
topk_group=topk_group)
from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group)
return fused_experts(x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
use_fp8=True,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale)
class Fp8KVCacheMethod(BaseKVCacheMethod):