[Kernel] W8A16 Int8 inside FusedMoE (#7415)

Author: Mor Zusman
Date: 2024-08-16 20:06:51 +03:00
Committed by: GitHub
Parent: e837b624f2
Commit: 7fc23be81c
15 changed files with 412 additions and 136 deletions
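As background for the diff below, here is a minimal PyTorch sketch of the weight-only int8 idea (W8A16) that this change brings into the fused MoE path: weights are stored as int8 with per-output-channel scales and dequantized to the 16-bit activation dtype at matmul time. This is only an illustrative reference, not the fused Triton kernel, and the helper names are invented for the example.

import torch

def quantize_w8(weight: torch.Tensor):
    # Symmetric per-output-channel int8 quantization of a [out, in] weight.
    scale = weight.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
    qweight = torch.clamp(torch.round(weight / scale), -128, 127).to(torch.int8)
    return qweight, scale

def w8a16_linear(x: torch.Tensor, qweight: torch.Tensor, scale: torch.Tensor):
    # Dequantize the weight to the activation dtype (the "A16" part) and matmul.
    w = qweight.to(x.dtype) * scale.to(x.dtype)
    return x @ w.t()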


@@ -16,7 +16,6 @@ from vllm.attention.layer import Attention
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -249,37 +248,6 @@ class JambaMambaMixer(nn.Module):
        return hidden_states


class JambaMLP(nn.Module):

    def __init__(
        self,
        config: JambaConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        hidden_size = config.hidden_size
        intermediate_size = config.intermediate_size
        hidden_act = config.hidden_act
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           quant_config=quant_config)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class JambaMoE(nn.Module):

    def __init__(self,
@@ -327,6 +295,21 @@ class JambaMoE(nn.Module):
        return hidden_states.view(orig_shape)


class JambaMLP(JambaMoE):

    def __init__(self,
                 config: JambaConfig,
                 params_dtype: Optional[torch.dtype] = None,
                 tp_size: Optional[int] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__(config,
                         num_experts=1,
                         top_k=1,
                         params_dtype=params_dtype,
                         tp_size=tp_size,
                         quant_config=quant_config)


class JambaMambaDecoderLayer(nn.Module):

    def __init__(self,
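The new JambaMLP above reuses the MoE machinery with a single expert and top_k=1, so routing is trivial and the fused kernel ends up computing an ordinary SwiGLU feed-forward. A rough, unfused reference of what that one expert computes (function and argument names are illustrative only, not vLLM's API):

import torch
import torch.nn.functional as F

def single_expert_mlp(x: torch.Tensor, w_gate: torch.Tensor, w_up: torch.Tensor,
                      w_down: torch.Tensor) -> torch.Tensor:
    # x: [tokens, hidden]; w_gate, w_up: [intermediate, hidden]; w_down: [hidden, intermediate]
    return (F.silu(x @ w_gate.t()) * (x @ w_up.t())) @ w_down.t()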
@@ -884,8 +867,6 @@ class JambaForCausalLM(nn.Module, HasInnerState):
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
# Params for weights, fp8 weight scales, fp8 activation scales
@@ -907,6 +888,10 @@ class JambaForCausalLM(nn.Module, HasInnerState):
if ".self_attn." in name:
name = name.replace(".self_attn", "")
if "feed_forward" in name and not _is_moe_layer(name):
## map MLP layers to expert with ID=0
name = name.replace("feed_forward", "feed_forward.experts.0")
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
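To make the remapping above concrete: a dense-MLP checkpoint key gets rewritten into the single expert slot, while true MoE keys (which contain "experts" or "router") are left untouched. The parameter names below are hypothetical examples, not actual Jamba checkpoint keys:

# "model.layers.5.feed_forward.down_proj.weight"
#     -> "model.layers.5.feed_forward.experts.0.down_proj.weight"
# "model.layers.4.feed_forward.router.weight"
#     -> unchanged (_is_moe_layer matches "router")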
@@ -921,16 +906,21 @@ class JambaForCausalLM(nn.Module, HasInnerState):
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                for (
                        param_name,
                        weight_name,
                        expert_id,
                        shard_id,
                ) in expert_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param,
                                  loaded_weight,
                                  name,
                                  weight_name,
                                  shard_id=shard_id,
                                  expert_id=expert_id)
                    break
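The loader above leans on Python's for/else: the else branch runs only when the for loop completes without hitting break, i.e. when none of the stacked (or expert) mappings matched the weight name. A minimal illustration of the idiom, with a made-up weight name:

for candidate in ("q_proj", "k_proj"):
    if candidate in "model.layers.0.feed_forward.router.weight":
        break
else:
    print("no stacked mapping matched; try the expert mappings / default loader")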
@@ -943,3 +933,11 @@ class JambaForCausalLM(nn.Module, HasInnerState):
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)


def _is_moe_layer(name: str):
    return any(
        [experts_name in name for experts_name in [
            "experts",
            "router",
        ]])
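A quick sanity check of the helper above (the parameter names are hypothetical):

assert _is_moe_layer("model.layers.3.feed_forward.experts.2.w1")        # contains "experts"
assert _is_moe_layer("model.layers.3.feed_forward.router.weight")       # contains "router"
assert not _is_moe_layer("model.layers.5.feed_forward.down_proj.weight")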