[Kernel] Expand MoE weight loading + Add Fused Marlin MoE Kernel (#7766)

Co-authored-by: ElizaWszola <eliza@neuralmagic.com>
Author: Dipika Sikka
Date: 2024-08-27 18:07:09 -04:00
Committed by: GitHub
Parent: ed6f002d33
Commit: fc911880cc
16 changed files with 2383 additions and 86 deletions


@@ -7,7 +7,8 @@ from torch.nn.parameter import Parameter
 import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
+from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
+                                                  FusedMoeWeightScaleSupported)
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization.base_config import (
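
For reference, `FusedMoeWeightScaleSupported` is the enum this change imports to tag the granularity at which weight scales are serialized. Only `TENSOR` is used in this file; a minimal sketch consistent with the "per tensor/grouped/channel" comment in the next hunk (member names other than `TENSOR` are inferred, not quoted from the source):

```python
from enum import Enum

class FusedMoeWeightScaleSupported(Enum):
    # Granularity at which weight scales are stored in the checkpoint.
    TENSOR = "tensor"    # one scale per weight tensor (used by Fp8MoEMethod)
    CHANNEL = "channel"  # one scale per output channel (inferred)
    GROUP = "group"      # one scale per group of input channels (inferred)
```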
@@ -332,19 +333,16 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                                                 dtype=torch.float32),
                                              requires_grad=False)
         layer.register_parameter("w2_weight_scale", w2_weight_scale)
+        # Add the quantization method used (per tensor/grouped/channel)
+        # to ensure the weight scales are loaded in properly
+        extra_weight_attrs.update(
+            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
         # If loading fp8 checkpoint, pass the weight loaders.
         # If loading an fp16 checkpoint, do not (we will quantize in
         #   process_weights_after_loading()
         if self.quant_config.is_checkpoint_fp8_serialized:
-            set_weight_attrs(w13_weight_scale, {
-                "is_fp8_scale": True,
-                **extra_weight_attrs
-            })
-            set_weight_attrs(w2_weight_scale, {
-                "is_fp8_scale": True,
-                **extra_weight_attrs
-            })
+            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
         # INPUT_SCALES
         if self.quant_config.activation_scheme == "static":
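
The `quant_method` attribute attached above only matters because the MoE weight loader reads it back when materializing checkpoint shards. The real `FusedMoE` loader also handles sharding and expert/TP indexing; the following is a hypothetical, much-reduced sketch of just the dispatch on scale granularity (`load_scale` and its signature are inventions for illustration):

```python
import torch

from vllm.model_executor.layers.fused_moe import FusedMoeWeightScaleSupported

def load_scale(param: torch.nn.Parameter, loaded_weight: torch.Tensor,
               expert_id: int) -> None:
    # Dispatch on the granularity tag attached via set_weight_attrs().
    quant_method = getattr(param, "quant_method", None)
    if quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
        # Per-tensor: a single scalar scale for this expert.
        param.data[expert_id] = loaded_weight
    elif quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
        # Per-channel: one scale per output channel of this expert.
        param.data[expert_id].copy_(loaded_weight)
    else:
        raise ValueError(f"Unsupported scale granularity: {quant_method}")
```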
@@ -357,19 +355,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 num_experts, dtype=torch.float32),
                                                  requires_grad=False)
             layer.register_parameter("w13_input_scale", w13_input_scale)
-            set_weight_attrs(w13_input_scale, {
-                "is_fp8_scale": True,
-                **extra_weight_attrs
-            })
+            set_weight_attrs(w13_input_scale, extra_weight_attrs)
             w2_input_scale = torch.nn.Parameter(torch.ones(
                 num_experts, dtype=torch.float32),
                                                 requires_grad=False)
             layer.register_parameter("w2_input_scale", w2_input_scale)
-            set_weight_attrs(w2_input_scale, {
-                "is_fp8_scale": True,
-                **extra_weight_attrs
-            })
+            set_weight_attrs(w2_input_scale, extra_weight_attrs)
         else:
             layer.w13_input_scale = None
             layer.w2_input_scale = None
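
`set_weight_attrs` is what makes this scheme work: it pins loader metadata (such as `weight_loader` and the new `quant_method`) onto the `Parameter` object so the model loader can read it back later. A minimal sketch of its behavior, paraphrased rather than copied from the vLLM source:

```python
from typing import Any, Dict, Optional

import torch

def set_weight_attrs(weight: torch.Tensor,
                     weight_attrs: Optional[Dict[str, Any]]) -> None:
    """Attach loader metadata to a tensor as plain Python attributes."""
    if weight_attrs is None:
        return
    for key, value in weight_attrs.items():
        # Refuse to silently overwrite an attribute that already exists.
        assert not hasattr(weight, key), f"Overwriting existing attr {key}"
        setattr(weight, key, value)
```

With the old code, the two input-scale branches duplicated an `is_fp8_scale` marker; after this commit the shared `extra_weight_attrs` dict (now carrying `quant_method`) is attached directly, so all four scale parameters are tagged uniformly.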