[MODEL] FalconH1 (#18406)
Signed-off-by: dhia.rhaiem <dhia.rhaiem@tii.ae> Co-authored-by: younesbelkada <younesbelkada@gmail.com> Co-authored-by: Ilyas Chahed <ilyas.chahed@tii.ae> Co-authored-by: Jingwei Zuo <jingwei.zuo@tii.ae>
This commit is contained in:
committed by
GitHub
parent
61acfc45bc
commit
eca18691d2
@@ -34,7 +34,11 @@ from vllm.model_executor.utils import set_weight_attrs
|
||||
@CustomOp.register("mixer2_gated_rms_norm")
|
||||
class Mixer2RMSNormGated(CustomOp):
|
||||
|
||||
def __init__(self, full_hidden_size, full_n_groups, eps=1e-6):
|
||||
def __init__(self,
|
||||
full_hidden_size: int,
|
||||
full_n_groups: int,
|
||||
use_rms_norm: bool = True,
|
||||
eps: float = 1e-6):
|
||||
super().__init__()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
@@ -44,11 +48,17 @@ class Mixer2RMSNormGated(CustomOp):
|
||||
self.n_groups = full_hidden_size // self.group_size
|
||||
|
||||
self.variance_epsilon = eps
|
||||
self.weight = nn.Parameter(torch.ones(self.per_rank_hidden_size))
|
||||
set_weight_attrs(self.weight,
|
||||
{"weight_loader": sharded_weight_loader(0)})
|
||||
assert self.full_hidden_size % self.tp_size== 0,\
|
||||
"Tensor parallel world size must divide hidden size."
|
||||
self.use_rms_norm = use_rms_norm
|
||||
if self.use_rms_norm:
|
||||
# Register norm weight only if we're actually applying RMSNorm
|
||||
self.weight = nn.Parameter(torch.ones(self.per_rank_hidden_size))
|
||||
set_weight_attrs(self.weight,
|
||||
{"weight_loader": sharded_weight_loader(0)})
|
||||
else:
|
||||
# Avoid checkpoint mismatch by skipping unused parameter
|
||||
self.register_parameter("weight", None)
|
||||
assert (self.full_hidden_size % self.tp_size == 0
|
||||
), "Tensor parallel world size must divide hidden size."
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
@@ -66,6 +76,8 @@ class Mixer2RMSNormGated(CustomOp):
|
||||
# the input and then redundantly compute the RMSNorm.
|
||||
input_dtype = x.dtype
|
||||
x = x * nn.functional.silu(gate.to(torch.float32))
|
||||
if not self.use_rms_norm:
|
||||
return x
|
||||
|
||||
if self.n_groups == 1:
|
||||
if self.tp_size > 1:
|
||||
@@ -74,7 +86,7 @@ class Mixer2RMSNormGated(CustomOp):
|
||||
global_sums = tensor_model_parallel_all_reduce(local_sums)
|
||||
# Calculate the variance
|
||||
count = self.tp_size * x.shape[-1]
|
||||
variance = (global_sums / count)
|
||||
variance = global_sums / count
|
||||
|
||||
else:
|
||||
variance = x.pow(2).mean(-1, keepdim=True)
|
||||
@@ -106,6 +118,9 @@ class Mixer2RMSNormGated(CustomOp):
|
||||
gate: torch.Tensor,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
|
||||
if not self.use_rms_norm:
|
||||
return x * nn.functional.silu(gate.to(torch.float32))
|
||||
|
||||
if self.tp_size > 1 or self.n_groups != 1:
|
||||
return self.forward_native(x, gate)
|
||||
|
||||
@@ -124,7 +139,7 @@ class Mixer2RMSNormGated(CustomOp):
|
||||
|
||||
|
||||
def extra_groups_for_head_shards(ngroups: int, tp_size: int):
|
||||
"""Compute the increase in group numbers to account for
|
||||
"""Compute the increase in group numbers to account for
|
||||
replication in order to accompany the head shards."""
|
||||
|
||||
# in the case ngroups % tp_size == 0, this will be zero
|
||||
@@ -182,13 +197,15 @@ def mamba_v2_sharded_weight_loader(
|
||||
# seem to handle slices well.
|
||||
# https://github.com/python/mypy/issues/2410
|
||||
param.data[
|
||||
boundary:(boundary + take), # type: ignore[misc]
|
||||
...] = loaded_weight[loaded_start_idx:( # type: ignore[misc]
|
||||
loaded_start_idx + take)] # type: ignore[misc]
|
||||
boundary:(boundary + take),
|
||||
... # type: ignore[misc]
|
||||
] = loaded_weight[loaded_start_idx:(loaded_start_idx +
|
||||
take) # type: ignore[misc]
|
||||
] # type: ignore[misc]
|
||||
|
||||
# move indexing boundaries
|
||||
boundary += shard_size
|
||||
loaded_boundary += (full_dim - extra)
|
||||
loaded_boundary += full_dim - extra
|
||||
|
||||
return loader
|
||||
|
||||
@@ -206,19 +223,22 @@ class MambaMixer2(CustomOp):
|
||||
**selective** state spaces)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
hidden_size: int,
|
||||
ssm_state_size: int,
|
||||
conv_kernel_size: int,
|
||||
intermediate_size: int,
|
||||
use_conv_bias: bool,
|
||||
use_bias: bool,
|
||||
n_groups: int = 1,
|
||||
num_heads: int = 128,
|
||||
head_dim: int = 64,
|
||||
rms_norm_eps: float = 1e-5,
|
||||
activation="silu",
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
ssm_state_size: int,
|
||||
conv_kernel_size: int,
|
||||
intermediate_size: int,
|
||||
use_conv_bias: bool,
|
||||
use_bias: bool,
|
||||
n_groups: int = 1,
|
||||
num_heads: int = 128,
|
||||
head_dim: int = 64,
|
||||
rms_norm_eps: float = 1e-5,
|
||||
activation: str = "silu",
|
||||
use_rms_norm: bool = True,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# For TP, the sharding plan is as follows:
|
||||
@@ -238,17 +258,16 @@ class MambaMixer2(CustomOp):
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
assert num_heads % self.tp_size == 0, \
|
||||
"Tensor parallel world size must divide num heads."
|
||||
assert (num_heads % self.tp_size == 0
|
||||
), "Tensor parallel world size must divide num heads."
|
||||
|
||||
assert (n_groups % self.tp_size) == 0 or n_groups == 1, \
|
||||
(
|
||||
"If tensor parallel world size does not divide num_heads, "
|
||||
"then num_groups must equal 1."
|
||||
)
|
||||
assert (n_groups % self.tp_size) == 0 or n_groups == 1, (
|
||||
"If tensor parallel world size does not divide num_heads, "
|
||||
"then num_groups must equal 1.")
|
||||
|
||||
assert self.tp_size == 1 or quant_config is None, \
|
||||
"Tensor parallel currently not supported for quantized models."
|
||||
assert (
|
||||
self.tp_size == 1 or quant_config is None
|
||||
), "Tensor parallel currently not supported for quantized models."
|
||||
|
||||
self.ssm_state_size = ssm_state_size
|
||||
self.activation = activation
|
||||
@@ -265,8 +284,7 @@ class MambaMixer2(CustomOp):
|
||||
self.n_groups = n_groups + extra_groups_for_head_shards(
|
||||
n_groups, self.tp_size)
|
||||
|
||||
self.conv_dim = (intermediate_size +
|
||||
2 * self.n_groups * ssm_state_size)
|
||||
self.conv_dim = intermediate_size + 2 * self.n_groups * ssm_state_size
|
||||
self.conv1d = ColumnParallelLinear(
|
||||
input_size=conv_kernel_size,
|
||||
output_size=self.conv_dim,
|
||||
@@ -279,11 +297,12 @@ class MambaMixer2(CustomOp):
|
||||
# doesn't allow to override it
|
||||
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
|
||||
|
||||
self.in_proj = ColumnParallelLinear(input_size=hidden_size,
|
||||
output_size=intermediate_size +
|
||||
self.conv_dim + self.num_heads,
|
||||
bias=use_bias,
|
||||
quant_config=quant_config)
|
||||
self.in_proj = ColumnParallelLinear(
|
||||
input_size=hidden_size,
|
||||
output_size=intermediate_size + self.conv_dim + self.num_heads,
|
||||
bias=use_bias,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
# - because in_proj is a concatenation of 3 weights, we
|
||||
# need to interleave them before sharding
|
||||
@@ -305,7 +324,8 @@ class MambaMixer2(CustomOp):
|
||||
# - ditto for the other two weights below
|
||||
delattr(self.conv1d.bias, "weight_loader")
|
||||
set_weight_attrs(
|
||||
self.conv1d.bias, {
|
||||
self.conv1d.bias,
|
||||
{
|
||||
"weight_loader":
|
||||
mamba_v2_sharded_weight_loader(
|
||||
[
|
||||
@@ -316,18 +336,25 @@ class MambaMixer2(CustomOp):
|
||||
self.tp_size,
|
||||
tp_rank,
|
||||
)
|
||||
})
|
||||
},
|
||||
)
|
||||
|
||||
delattr(self.conv1d.weight, "weight_loader")
|
||||
set_weight_attrs(
|
||||
self.conv1d.weight, {
|
||||
self.conv1d.weight,
|
||||
{
|
||||
"weight_loader":
|
||||
mamba_v2_sharded_weight_loader([
|
||||
intermediate_settings,
|
||||
group_shard_settings,
|
||||
group_shard_settings,
|
||||
], self.tp_size, tp_rank)
|
||||
})
|
||||
mamba_v2_sharded_weight_loader(
|
||||
[
|
||||
intermediate_settings,
|
||||
group_shard_settings,
|
||||
group_shard_settings,
|
||||
],
|
||||
self.tp_size,
|
||||
tp_rank,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
if quant_config is None:
|
||||
# - quant layers do not have a weight loader
|
||||
@@ -345,8 +372,10 @@ class MambaMixer2(CustomOp):
|
||||
head_setings, # for dt
|
||||
],
|
||||
self.tp_size,
|
||||
tp_rank)
|
||||
})
|
||||
tp_rank,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
# - these are TPed by heads to reduce the size of the
|
||||
# temporal shape
|
||||
@@ -357,6 +386,7 @@ class MambaMixer2(CustomOp):
|
||||
))
|
||||
self.D = nn.Parameter(torch.ones(num_heads // self.tp_size))
|
||||
self.dt_bias = nn.Parameter(torch.ones(num_heads // self.tp_size))
|
||||
self.use_rms_norm = use_rms_norm
|
||||
|
||||
set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)})
|
||||
a_weight_loader = composed_weight_loader(
|
||||
@@ -365,18 +395,25 @@ class MambaMixer2(CustomOp):
|
||||
set_weight_attrs(self.dt_bias,
|
||||
{"weight_loader": sharded_weight_loader(0)})
|
||||
|
||||
self.out_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=use_bias,
|
||||
input_is_parallel=True,
|
||||
quant_config=quant_config)
|
||||
self.out_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=use_bias,
|
||||
input_is_parallel=True,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
self.norm = Mixer2RMSNormGated(intermediate_size,
|
||||
n_groups,
|
||||
self.use_rms_norm,
|
||||
eps=rms_norm_eps)
|
||||
|
||||
def forward_native(self, hidden_states: torch.Tensor,
|
||||
conv_state: torch.Tensor, ssm_state: torch.Tensor):
|
||||
def forward_native(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
conv_state: torch.Tensor,
|
||||
ssm_state: torch.Tensor,
|
||||
):
|
||||
pass
|
||||
|
||||
def forward_cuda(
|
||||
@@ -384,6 +421,7 @@ class MambaMixer2(CustomOp):
|
||||
hidden_states: torch.Tensor,
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
mamba2_metadata: Mamba2Metadata,
|
||||
mup_vector: Optional[torch.Tensor] = None,
|
||||
):
|
||||
# mamba2_metadata contains metadata necessary for the mamba2 triton
|
||||
# kernels to operate in continuous batching and in chunked prefill
|
||||
@@ -401,6 +439,10 @@ class MambaMixer2(CustomOp):
|
||||
|
||||
# 1. Gated MLP's linear projection
|
||||
projected_states, _ = self.in_proj(hidden_states)
|
||||
|
||||
if mup_vector is not None:
|
||||
projected_states = projected_states * mup_vector
|
||||
|
||||
gate, hidden_states_B_C, dt = torch.split(
|
||||
projected_states,
|
||||
[
|
||||
@@ -561,6 +603,9 @@ class MambaMixer2(CustomOp):
|
||||
hidden_states = torch.vstack(ssd_output_list)
|
||||
|
||||
# 4. gated MLP
|
||||
# GatedRMSNorm internally applying SiLU to the gate
|
||||
# SiLU is applied internally before normalization, unlike standard
|
||||
# norm usage
|
||||
hidden_states = self.norm(hidden_states, gate)
|
||||
|
||||
# 5. Final linear projection
|
||||
|
||||
Reference in New Issue
Block a user