[Model] LoRA support for Jamba model (#11209)
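Adds LoRA support for the Jamba model: an `is_lora_enabled` flag is
derived from `vllm_config.lora_config` and threaded through the decoder
layers into MambaMixer, and the Mamba projections ("in_proj", "x_proj",
"out_proj") are registered as LoRA targets in `supported_lora_modules`
alongside the attention and MLP projections.

With this change, a LoRA adapter should apply to Jamba the same way as
to other vLLM models. A minimal usage sketch (the model id and adapter
path below are illustrative placeholders, not part of this commit):

    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest

    # enable_lora populates vllm_config.lora_config, which is what
    # flips is_lora_enabled=True in the Jamba layers touched here.
    llm = LLM(model="ai21labs/AI21-Jamba-1.5-Mini", enable_lora=True)

    outputs = llm.generate(
        "Hello, my name is",
        SamplingParams(max_tokens=32),
        # adapter name, unique int id, and local path (placeholder)
        lora_request=LoRARequest("jamba-adapter", 1, "/path/to/adapter"),
    )
    print(outputs[0].outputs[0].text)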
Signed-off-by: Erez Schwartz <erezs@ai21.com>
diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py
--- a/vllm/model_executor/models/jamba.py
+++ b/vllm/model_executor/models/jamba.py
@@ -107,9 +107,11 @@ class JambaMambaDecoderLayer(nn.Module):
                  layer_idx: int,
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = "") -> None:
+                 is_lora_enabled: Optional[bool] = False,
+                 **kwargs) -> None:
         super().__init__()
         self.config = config
+        self.is_lora_enabled = is_lora_enabled
         self.mamba = MambaMixer(hidden_size= config.hidden_size,
                                 ssm_state_size = config.mamba_d_state,
                                 conv_kernel_size = config.mamba_d_conv,
@@ -120,7 +122,9 @@ class JambaMambaDecoderLayer(nn.Module):
                                 use_bias = config.mamba_proj_bias,
                                 use_rms_norm=True,
                                 rms_norm_eps=config.rms_norm_eps,
-                                activation=config.hidden_act)
+                                activation=config.hidden_act,
+                                is_lora_enabled = self.is_lora_enabled
+                                )
 
         num_experts = config.layers_num_experts[layer_idx]
         ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
@@ -156,14 +160,13 @@ class JambaMambaDecoderLayer(nn.Module):
 
 class JambaAttentionDecoderLayer(nn.Module):
 
-    def __init__(
-        self,
-        config: JambaConfig,
-        layer_idx: int,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ) -> None:
+    def __init__(self,
+                 config: JambaConfig,
+                 layer_idx: int,
+                 cache_config: Optional[CacheConfig] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = "",
+                 **kwargs) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
         tp_size = get_tensor_model_parallel_world_size()
@@ -287,17 +290,18 @@ class JambaModel(nn.Module):
             org_num_embeddings=config.vocab_size,
         )
 
+        extra_kwargs = {"is_lora_enabled": bool(vllm_config.lora_config)}
+
         def get_layer(prefix: str):
             layer_idx = int(prefix.rsplit(".", 1)[1])
             layer_class = ALL_DECODER_LAYER_TYPES[
                 config.layers_block_type[layer_idx]]
-            return layer_class(
-                config,
-                layer_idx,
-                cache_config,
-                quant_config=quant_config,
-                prefix=prefix,
-            )
+            return layer_class(config,
+                               layer_idx,
+                               cache_config,
+                               quant_config=quant_config,
+                               prefix=prefix,
+                               **extra_kwargs)
 
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers")
@@ -371,14 +375,13 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
             "k_proj",
             "v_proj",
         ],
+        "in_proj": ["in_proj"],
     }
 
     # LoRA specific attributes
     supported_lora_modules = [
-        "qkv_proj",
-        "o_proj",
-        "embed_tokens",
-        "lm_head",
+        "qkv_proj", "o_proj", "embed_tokens", "lm_head", "up_proj",
+        "down_proj", "gate_proj", "out_proj", "in_proj", "x_proj"
     ]
     embedding_modules = {
         "embed_tokens": "input_embeddings",
@@ -423,9 +426,9 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
         if self.scheduler_config is not None and \
-            not self.model_config.enforce_eager:
+                not self.model_config.enforce_eager:
             if self.scheduler_config.max_num_seqs > \
-                vllm_config.compilation_config.max_capture_size:
+                    vllm_config.compilation_config.max_capture_size:
                 self.max_batch_size = \
                     vllm_config.compilation_config.max_capture_size
             else:
@@ -446,7 +449,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
                 inputs_embeds: Optional[torch.Tensor] = None,
                 **kwargs):
         if self.mamba_cache is None:
-
             num_mamba_layers = self.model_config.get_num_layers_by_block_type(
                 self.vllm_config.parallel_config, LayerBlockType.mamba)
             self.mamba_cache = MambaCacheManager(