[MODEL] LoRA support for Jamba model (#11209)

Signed-off-by: Erez Schwartz <erezs@ai21.com>
ErezSC42 authored 2024-12-27 19:58:21 +02:00
committed by GitHub
parent 101418096f
commit 55509c2114
5 changed files with 132 additions and 32 deletions
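For context, a minimal sketch of how a Jamba LoRA adapter could be served once this support lands, using vLLM's existing `enable_lora` / `LoRARequest` flow. The model name, adapter name, path, and rank below are illustrative placeholders, not values taken from this commit:

```python
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Illustrative only: any Jamba checkpoint and any trained LoRA adapter path would do.
llm = LLM(model="ai21labs/Jamba-v0.1", enable_lora=True, max_lora_rank=16)

outputs = llm.generate(
    ["Summarize the Mamba state space model in one sentence."],
    SamplingParams(temperature=0.0, max_tokens=64),
    # name, integer id, and local path of the (hypothetical) adapter
    lora_request=LoRARequest("jamba-adapter", 1, "/path/to/jamba-lora"),
)
print(outputs[0].outputs[0].text)
```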


@@ -42,12 +42,14 @@ class MambaMixer(CustomOp):
                  use_rms_norm: bool,
                  rms_norm_has_weight: bool = True,
                  rms_norm_eps: float = 1e-5,
-                 activation="silu"):
+                 activation="silu",
+                 is_lora_enabled: bool = False):
         super().__init__()
         self.time_step_rank = time_step_rank
         self.ssm_state_size = ssm_state_size
         self.use_rms_norm = use_rms_norm
         self.activation = activation
+        self.is_lora_enabled = is_lora_enabled
         self.conv1d = ColumnParallelLinear(
             input_size=conv_kernel_size,
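The flag itself is set by the calling model; the Jamba-side changes are in the other files of this commit and are not shown in this excerpt. A rough sketch of what that wiring might look like, with the layer structure and config attribute names treated as assumptions rather than quotes from the diff:

```python
from torch import nn
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer

class JambaMambaDecoderLayer(nn.Module):  # hypothetical simplification of the real layer
    def __init__(self, config, lora_config=None):
        super().__init__()
        # MambaMixer is the class patched in this diff; the flag is derived from
        # whether a LoRA config was supplied to the model at construction time.
        self.mamba = MambaMixer(
            hidden_size=config.hidden_size,
            ssm_state_size=config.mamba_d_state,
            conv_kernel_size=config.mamba_d_conv,
            intermediate_size=config.mamba_expand * config.hidden_size,
            time_step_rank=config.mamba_dt_rank,
            use_conv_bias=config.mamba_conv_bias,
            use_bias=config.mamba_proj_bias,
            use_rms_norm=True,
            is_lora_enabled=bool(lora_config),
        )
```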
@@ -63,6 +65,7 @@ class MambaMixer(CustomOp):
         self.in_proj = MergedColumnParallelLinear(hidden_size,
                                                   [intermediate_size] * 2,
                                                   bias=use_bias)
         # selective projection used to make dt, B and C input dependent
         self.x_proj = RowParallelLinear(
             intermediate_size,
@@ -170,7 +173,13 @@ class MambaMixer(CustomOp):
         # 3. State Space Model sequence transformation
         # 3.a. input varying initialization of time_step, B and C
-        ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0]
+        if self.is_lora_enabled:
+            # lora kernel requires contiguous tensor
+            ssm_parameters = self.x_proj(
+                hidden_states.transpose(-2, -1).contiguous())[0]
+        else:
+            ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0]
         time_step, B, C = torch.split(
             ssm_parameters,
@@ -222,6 +231,11 @@ class MambaMixer(CustomOp):
             scan_outputs = scan_outputs.transpose(0, 1)
         # 4. Final linear projection
-        contextualized_states = self.out_proj(scan_outputs.transpose(-2,
-                                                                      -1))[0]
+        if self.is_lora_enabled:
+            # lora kernel requires contiguous tensor
+            contextualized_states = self.out_proj(
+                scan_outputs.transpose(-2, -1).contiguous())[0]
+        else:
+            contextualized_states = self.out_proj(
+                scan_outputs.transpose(-2, -1))[0]
         return contextualized_states
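The repeated `# lora kernel requires contiguous tensor` comment is the crux of the forward-pass changes: `transpose(-2, -1)` returns a strided view of the same storage, while the LoRA kernels expect a dense, row-major buffer. A standalone PyTorch illustration of the difference (shapes are arbitrary examples):

```python
import torch

x = torch.randn(2, 1024, 64)   # e.g. (batch, intermediate_size, seq_len)
t = x.transpose(-2, -1)        # swaps strides only; no data is copied
print(t.is_contiguous())       # False: memory layout no longer matches the shape
c = t.contiguous()             # materializes a dense, row-major copy
print(c.is_contiguous())       # True: safe for kernels that assume dense layout
print(torch.equal(t, c))       # True: same values, different memory layout
```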