[perf] Add fused MLA QKV + strided layernorm (#21116)
Signed-off-by: Mickael Seznec <mickael@mistral.ai> Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@@ -259,6 +259,8 @@ class LinearBase(torch.nn.Module):
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
self.params_dtype = params_dtype
|
||||
self.quant_config = quant_config
|
||||
self.prefix = prefix
|
||||
if quant_config is None:
|
||||
self.quant_method: Optional[
|
||||
QuantizeMethodBase] = UnquantizedLinearMethod()
|
||||
@@ -300,6 +302,12 @@ class ReplicatedLinear(LinearBase):
|
||||
*,
|
||||
return_bias: bool = True,
|
||||
):
|
||||
# If MergedReplicatedLinear, use output size of each partition.
|
||||
if hasattr(self, "output_sizes"):
|
||||
self.output_partition_sizes = self.output_sizes
|
||||
else:
|
||||
self.output_partition_sizes = [output_size]
|
||||
|
||||
super().__init__(input_size,
|
||||
output_size,
|
||||
skip_bias_add,
|
||||
@@ -311,7 +319,8 @@ class ReplicatedLinear(LinearBase):
|
||||
# All the linear layer supports quant method.
|
||||
assert self.quant_method is not None
|
||||
self.quant_method.create_weights(self,
|
||||
self.input_size, [self.output_size],
|
||||
self.input_size,
|
||||
self.output_partition_sizes,
|
||||
self.input_size,
|
||||
self.output_size,
|
||||
self.params_dtype,
|
||||
@@ -367,6 +376,73 @@ class ReplicatedLinear(LinearBase):
|
||||
return s
|
||||
|
||||
|
||||
class MergedReplicatedLinear(ReplicatedLinear):
|
||||
"""Replicated linear layer.
|
||||
|
||||
Args:
|
||||
input_size: input dimension of the linear layer.
|
||||
output_size: output dimension of the linear layer.
|
||||
bias: If true, add bias.
|
||||
skip_bias_add: If true, skip adding bias but instead return it.
|
||||
params_dtype: Data type for the parameters.
|
||||
quant_config: Quantization configure.
|
||||
prefix: The name of the layer in the state dict, including all parents
|
||||
(e.g. model.layers.0.qkv_proj)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_sizes: list[int],
|
||||
bias: bool = True,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
*,
|
||||
return_bias: bool = True,
|
||||
):
|
||||
self.output_sizes = output_sizes
|
||||
super().__init__(input_size,
|
||||
sum(output_sizes),
|
||||
bias,
|
||||
skip_bias_add,
|
||||
params_dtype,
|
||||
quant_config,
|
||||
prefix=prefix,
|
||||
return_bias=return_bias)
|
||||
|
||||
def weight_loader(self,
|
||||
param: Union[Parameter, BasevLLMParameter],
|
||||
loaded_weight: torch.Tensor,
|
||||
loaded_shard_id: Optional[int] = None):
|
||||
assert loaded_shard_id is not None
|
||||
assert loaded_shard_id < len(self.output_sizes)
|
||||
|
||||
if isinstance(param, BlockQuantScaleParameter):
|
||||
from vllm.model_executor.layers.quantization.fp8 import (
|
||||
Fp8LinearMethod, Fp8MoEMethod)
|
||||
assert self.quant_method is not None
|
||||
assert isinstance(self.quant_method,
|
||||
(Fp8LinearMethod, Fp8MoEMethod))
|
||||
weight_block_size = self.quant_method.quant_config.weight_block_size
|
||||
assert weight_block_size is not None
|
||||
block_n, _ = weight_block_size[0], weight_block_size[1]
|
||||
shard_offset = (
|
||||
(sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) //
|
||||
block_n)
|
||||
shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) //
|
||||
block_n)
|
||||
elif isinstance(param, PerTensorScaleParameter):
|
||||
shard_offset = loaded_shard_id
|
||||
shard_size = 1
|
||||
else:
|
||||
shard_offset = sum(self.output_sizes[:loaded_shard_id])
|
||||
shard_size = self.output_sizes[loaded_shard_id]
|
||||
|
||||
param[shard_offset:shard_offset + shard_size] = loaded_weight
|
||||
|
||||
|
||||
class ColumnParallelLinear(LinearBase):
|
||||
"""Linear layer with column parallelism.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user