[MODEL] LoRA support for Jamba model (#11209)

Signed-off-by: Erez Schwartz <erezs@ai21.com>
Author: ErezSC42
Date: 2024-12-27 19:58:21 +02:00
Committed by: GitHub
Parent: 101418096f
Commit: 55509c2114
5 changed files with 132 additions and 32 deletions


@@ -107,9 +107,11 @@ class JambaMambaDecoderLayer(nn.Module):
                  layer_idx: int,
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = "") -> None:
+                 is_lora_enabled: Optional[bool] = False,
+                 **kwargs) -> None:
         super().__init__()
         self.config = config
+        self.is_lora_enabled = is_lora_enabled
         self.mamba = MambaMixer(hidden_size= config.hidden_size,
                                 ssm_state_size = config.mamba_d_state,
                                 conv_kernel_size = config.mamba_d_conv,
@@ -120,7 +122,9 @@ class JambaMambaDecoderLayer(nn.Module):
                                 use_bias = config.mamba_proj_bias,
                                 use_rms_norm=True,
                                 rms_norm_eps=config.rms_norm_eps,
-                                activation=config.hidden_act)
+                                activation=config.hidden_act,
+                                is_lora_enabled = self.is_lora_enabled
+                                )

         num_experts = config.layers_num_experts[layer_idx]
         ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
@@ -156,14 +160,13 @@ class JambaMambaDecoderLayer(nn.Module):

 class JambaAttentionDecoderLayer(nn.Module):

-    def __init__(
-        self,
-        config: JambaConfig,
-        layer_idx: int,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ) -> None:
+    def __init__(self,
+                 config: JambaConfig,
+                 layer_idx: int,
+                 cache_config: Optional[CacheConfig] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = "",
+                 **kwargs) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
         tp_size = get_tensor_model_parallel_world_size()
@@ -287,17 +290,18 @@ class JambaModel(nn.Module):
             org_num_embeddings=config.vocab_size,
         )

+        extra_kwargs = {"is_lora_enabled": bool(vllm_config.lora_config)}
+
         def get_layer(prefix: str):
             layer_idx = int(prefix.rsplit(".", 1)[1])
             layer_class = ALL_DECODER_LAYER_TYPES[
                 config.layers_block_type[layer_idx]]
-            return layer_class(
-                config,
-                layer_idx,
-                cache_config,
-                quant_config=quant_config,
-                prefix=prefix,
-            )
+            return layer_class(config,
+                               layer_idx,
+                               cache_config,
+                               quant_config=quant_config,
+                               prefix=prefix,
+                               **extra_kwargs)

         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers")
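
The hunk above shows the propagation pattern: the model reads the engine's LoRA config once, packs it into extra_kwargs, and forwards it through the get_layer factory; the Mamba decoder layers consume is_lora_enabled explicitly while the attention layers absorb it via **kwargs. A minimal, self-contained sketch of that pattern (class and function names here are illustrative, not the actual vLLM types):

from typing import Optional

class MambaLikeLayer:
    # Consumes the flag explicitly, as JambaMambaDecoderLayer does above.
    def __init__(self,
                 hidden_size: int,
                 is_lora_enabled: Optional[bool] = False,
                 **kwargs) -> None:
        self.hidden_size = hidden_size
        self.is_lora_enabled = is_lora_enabled

class AttentionLikeLayer:
    # Ignores the flag; **kwargs absorbs it so the factory call stays uniform.
    def __init__(self, hidden_size: int, **kwargs) -> None:
        self.hidden_size = hidden_size

def build_layers(layer_types, hidden_size, lora_config):
    # Built once from the engine config, then forwarded to every layer.
    extra_kwargs = {"is_lora_enabled": bool(lora_config)}
    classes = {"mamba": MambaLikeLayer, "attention": AttentionLikeLayer}
    return [classes[t](hidden_size, **extra_kwargs) for t in layer_types]

layers = build_layers(["mamba", "attention"], hidden_size=4096, lora_config={"r": 8})
assert layers[0].is_lora_enabled is True

Keeping the flag in a single dict lets further per-layer options travel the same path without touching every constructor signature.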
@@ -371,14 +375,13 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
"k_proj",
"v_proj",
],
"in_proj": ["in_proj"],
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"embed_tokens",
"lm_head",
"qkv_proj", "o_proj", "embed_tokens", "lm_head", "up_proj",
"down_proj", "gate_proj", "out_proj", "in_proj", "x_proj"
]
embedding_modules = {
"embed_tokens": "input_embeddings",
@@ -423,9 +426,9 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
         if self.scheduler_config is not None and \
-            not self.model_config.enforce_eager:
+                not self.model_config.enforce_eager:
             if self.scheduler_config.max_num_seqs > \
-                vllm_config.compilation_config.max_capture_size:
+                    vllm_config.compilation_config.max_capture_size:
                 self.max_batch_size = \
                     vllm_config.compilation_config.max_capture_size
             else:
@@ -446,7 +449,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
                 inputs_embeds: Optional[torch.Tensor] = None,
                 **kwargs):
         if self.mamba_cache is None:
             num_mamba_layers = self.model_config.get_num_layers_by_block_type(
                 self.vllm_config.parallel_config, LayerBlockType.mamba)
             self.mamba_cache = MambaCacheManager(