[MISC] Fix Tensor Parallelism for Quantized Mamba Models with n_groups=1 (#33257)
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
This commit is contained in:
@@ -17,7 +17,6 @@ from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.mamba.abstract import MambaBase
|
||||
@@ -40,6 +39,7 @@ from vllm.model_executor.model_loader.weight_utils import (
|
||||
composed_weight_loader,
|
||||
sharded_weight_loader,
|
||||
)
|
||||
from vllm.model_executor.parameter import BasevLLMParameter
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
@@ -280,13 +280,6 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
"then num_groups must equal 1."
|
||||
)
|
||||
|
||||
assert (
|
||||
(n_groups % self.tp_size == 0) or self.tp_size == 1 or quant_config is None
|
||||
), (
|
||||
"Tensor parallel currently supported for quantized models only "
|
||||
"if tensor parallel world size divides num groups."
|
||||
)
|
||||
|
||||
self.ssm_state_size = ssm_state_size
|
||||
self.conv_kernel_size = conv_kernel_size
|
||||
self.activation = activation
|
||||
@@ -308,121 +301,94 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
self.groups_ssm_state_size = self.n_groups * self.ssm_state_size
|
||||
self.conv_dim = intermediate_size + 2 * self.groups_ssm_state_size
|
||||
|
||||
if n_groups % self.tp_size == 0:
|
||||
self.conv1d = MergedColumnParallelLinear(
|
||||
input_size=conv_kernel_size,
|
||||
output_sizes=[
|
||||
intermediate_size,
|
||||
self.groups_ssm_state_size,
|
||||
self.groups_ssm_state_size,
|
||||
],
|
||||
bias=use_conv_bias,
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.conv1d",
|
||||
)
|
||||
# Use ColumnParallelLinear with custom weight loaders for both cases:
|
||||
# - When n_groups % tp_size == 0: standard sharding without duplication
|
||||
# - When n_groups == 1: groups are duplicated across TP ranks
|
||||
# The custom weight loader handles both cases correctly.
|
||||
|
||||
self.in_proj = MergedColumnParallelLinear(
|
||||
input_size=hidden_size,
|
||||
output_sizes=[
|
||||
intermediate_size,
|
||||
intermediate_size,
|
||||
self.groups_ssm_state_size,
|
||||
self.groups_ssm_state_size,
|
||||
self.num_heads,
|
||||
],
|
||||
bias=use_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.in_proj",
|
||||
)
|
||||
else:
|
||||
# This is the n_groups == 1 case,
|
||||
# where we need to duplicate groups if TP>1.
|
||||
self.conv1d = ColumnParallelLinear(
|
||||
input_size=conv_kernel_size,
|
||||
output_size=self.conv_dim,
|
||||
bias=use_conv_bias,
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.conv1d",
|
||||
)
|
||||
|
||||
self.conv1d = ColumnParallelLinear(
|
||||
input_size=conv_kernel_size,
|
||||
output_size=self.conv_dim,
|
||||
bias=use_conv_bias,
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.conv1d",
|
||||
)
|
||||
self.in_proj = ColumnParallelLinear(
|
||||
input_size=hidden_size,
|
||||
output_size=intermediate_size + self.conv_dim + self.num_heads,
|
||||
bias=use_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.in_proj",
|
||||
)
|
||||
|
||||
self.in_proj = ColumnParallelLinear(
|
||||
input_size=hidden_size,
|
||||
output_size=intermediate_size + self.conv_dim + self.num_heads,
|
||||
bias=use_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.in_proj",
|
||||
)
|
||||
# Configure shard settings for the custom weight loader:
|
||||
# - group_shard_settings handles group duplication when n_groups == 1
|
||||
# - When n_groups % tp_size == 0, extra=0 and duplicate_groups=False
|
||||
group_shard_settings = (
|
||||
self.groups_ssm_state_size, # expected model size
|
||||
(self.n_groups - n_groups) * self.ssm_state_size, # extra dims assigned
|
||||
n_groups == 1, # duplicate groups when n_groups == 1
|
||||
)
|
||||
intermediate_settings = (intermediate_size, 0, False)
|
||||
head_settings = (self.num_heads, 0, False)
|
||||
|
||||
# - because in_proj is a concatenation of 3 weights, we
|
||||
# need to interleave them before sharding
|
||||
# - use the custom weight loader mamba_v2_sharded_weight_loader
|
||||
# for conv1d.bias, covn1d.weight and in_proj.weight
|
||||
# - need to set these settings, to assign the groups
|
||||
# to the head shards
|
||||
group_shard_settings = (
|
||||
self.groups_ssm_state_size, # expected model size
|
||||
(self.n_groups - n_groups) * self.ssm_state_size, # extra dims assigned
|
||||
n_groups == 1, # if there was only one group
|
||||
)
|
||||
intermediate_settings = (intermediate_size, 0, False)
|
||||
head_settings = (self.num_heads, 0, False)
|
||||
|
||||
# - the weight already has a "weight_loader" attribute
|
||||
# which set_weight_attrs will raise if we do not
|
||||
# delete before trying to override it
|
||||
# - ditto for the other two weights below
|
||||
delattr(self.conv1d.bias, "weight_loader")
|
||||
set_weight_attrs(
|
||||
self.conv1d.bias,
|
||||
{
|
||||
"weight_loader": mamba_v2_sharded_weight_loader(
|
||||
[
|
||||
intermediate_settings,
|
||||
group_shard_settings,
|
||||
group_shard_settings,
|
||||
],
|
||||
self.tp_size,
|
||||
tp_rank,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
delattr(self.conv1d.weight, "weight_loader")
|
||||
set_weight_attrs(
|
||||
self.conv1d.weight,
|
||||
{
|
||||
"weight_loader": mamba_v2_sharded_weight_loader(
|
||||
[
|
||||
intermediate_settings,
|
||||
group_shard_settings,
|
||||
group_shard_settings,
|
||||
],
|
||||
self.tp_size,
|
||||
tp_rank,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
if quant_config is None:
|
||||
# - quant layers do not have a weight loader
|
||||
delattr(self.in_proj.weight, "weight_loader")
|
||||
set_weight_attrs(
|
||||
self.in_proj.weight,
|
||||
{
|
||||
"weight_loader": mamba_v2_sharded_weight_loader(
|
||||
[
|
||||
intermediate_settings, # for gate
|
||||
intermediate_settings,
|
||||
group_shard_settings,
|
||||
group_shard_settings,
|
||||
head_settings, # for dt
|
||||
],
|
||||
self.tp_size,
|
||||
tp_rank,
|
||||
)
|
||||
},
|
||||
# Apply custom weight loaders for conv1d (bias and weight)
|
||||
delattr(self.conv1d.bias, "weight_loader")
|
||||
set_weight_attrs(
|
||||
self.conv1d.bias,
|
||||
{
|
||||
"weight_loader": mamba_v2_sharded_weight_loader(
|
||||
[
|
||||
intermediate_settings,
|
||||
group_shard_settings,
|
||||
group_shard_settings,
|
||||
],
|
||||
self.tp_size,
|
||||
tp_rank,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
delattr(self.conv1d.weight, "weight_loader")
|
||||
set_weight_attrs(
|
||||
self.conv1d.weight,
|
||||
{
|
||||
"weight_loader": mamba_v2_sharded_weight_loader(
|
||||
[
|
||||
intermediate_settings,
|
||||
group_shard_settings,
|
||||
group_shard_settings,
|
||||
],
|
||||
self.tp_size,
|
||||
tp_rank,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
# Create the custom weight loader for in_proj
|
||||
mamba_loader = mamba_v2_sharded_weight_loader(
|
||||
[
|
||||
intermediate_settings, # for gate
|
||||
intermediate_settings,
|
||||
group_shard_settings,
|
||||
group_shard_settings,
|
||||
head_settings, # for dt
|
||||
],
|
||||
self.tp_size,
|
||||
tp_rank,
|
||||
)
|
||||
|
||||
# Apply the custom weight loader to in_proj.weight
|
||||
# Works for both non-quantized (Parameter) and quantized
|
||||
# (ModelWeightParameter which extends BasevLLMParameter)
|
||||
if isinstance(self.in_proj.weight, BasevLLMParameter):
|
||||
# For BasevLLMParameter subclasses (quantized layers like FP8)
|
||||
self.in_proj.weight.weight_loader = mamba_loader
|
||||
else:
|
||||
# For standard Parameter (non-quantized layers)
|
||||
delattr(self.in_proj.weight, "weight_loader")
|
||||
set_weight_attrs(self.in_proj.weight, {"weight_loader": mamba_loader})
|
||||
|
||||
# unsqueeze to fit conv1d weights shape into the linear weights shape.
|
||||
# Can't do this in `weight_loader` since it already exists in
|
||||
|
||||
Reference in New Issue
Block a user