[MODEL] LoRA support for Jamba model (#11209)
Signed-off-by: Erez Schwartz <erezs@ai21.com>
This commit is contained in:
@@ -42,12 +42,14 @@ class MambaMixer(CustomOp):
|
||||
use_rms_norm: bool,
|
||||
rms_norm_has_weight: bool = True,
|
||||
rms_norm_eps: float = 1e-5,
|
||||
activation="silu"):
|
||||
activation="silu",
|
||||
is_lora_enabled: bool = False):
|
||||
super().__init__()
|
||||
self.time_step_rank = time_step_rank
|
||||
self.ssm_state_size = ssm_state_size
|
||||
self.use_rms_norm = use_rms_norm
|
||||
self.activation = activation
|
||||
self.is_lora_enabled = is_lora_enabled
|
||||
|
||||
self.conv1d = ColumnParallelLinear(
|
||||
input_size=conv_kernel_size,
|
||||
@@ -63,6 +65,7 @@ class MambaMixer(CustomOp):
|
||||
self.in_proj = MergedColumnParallelLinear(hidden_size,
|
||||
[intermediate_size] * 2,
|
||||
bias=use_bias)
|
||||
|
||||
# selective projection used to make dt, B and C input dependent
|
||||
self.x_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
@@ -170,7 +173,13 @@ class MambaMixer(CustomOp):
|
||||
|
||||
# 3. State Space Model sequence transformation
|
||||
# 3.a. input varying initialization of time_step, B and C
|
||||
ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0]
|
||||
|
||||
if self.is_lora_enabled:
|
||||
# lora kernel requires contiguous tensor
|
||||
ssm_parameters = self.x_proj(
|
||||
hidden_states.transpose(-2, -1).contiguous())[0]
|
||||
else:
|
||||
ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0]
|
||||
|
||||
time_step, B, C = torch.split(
|
||||
ssm_parameters,
|
||||
@@ -222,6 +231,11 @@ class MambaMixer(CustomOp):
|
||||
scan_outputs = scan_outputs.transpose(0, 1)
|
||||
|
||||
# 4. Final linear projection
|
||||
contextualized_states = self.out_proj(scan_outputs.transpose(-2,
|
||||
-1))[0]
|
||||
if self.is_lora_enabled:
|
||||
# lora kernel requires contiguous tensor
|
||||
contextualized_states = self.out_proj(
|
||||
scan_outputs.transpose(-2, -1).contiguous())[0]
|
||||
else:
|
||||
contextualized_states = self.out_proj(
|
||||
scan_outputs.transpose(-2, -1))[0]
|
||||
return contextualized_states
|
||||
|
||||
Reference in New Issue
Block a user