Upstream Llama4 Support to Main (#16113)

Signed-off-by: Aston Zhang <22279212+astonzhang@users.noreply.github.com>
Signed-off-by: Chris Thi <chris.c.thi@gmail.com>
Signed-off-by: drisspg <drisspguessous@gmail.com>
Signed-off-by: Jon Swenson <jmswen@gmail.com>
Signed-off-by: Keyun Tong <tongkeyun@gmail.com>
Signed-off-by: Lu Fang <fanglu@meta.com>
Signed-off-by: Xiaodong Wang <xdwang@meta.com>
Signed-off-by: Yang Chen <yangche@fb.com>
Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
Signed-off-by: Zijing Liu <liuzijing2014@gmail.com>
Signed-off-by: Lu Fang <lufang@fb.com>
Signed-off-by: Lu Fang <fanglu@fb.com>
Signed-off-by: Lucia Fang <fanglu@fb.com>
Signed-off-by: Roger Wang <ywang@roblox.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Lu Fang <fanglu@fb.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Lu Fang
2025-04-07 08:06:27 -07:00
committed by GitHub
parent 8017c8db7f
commit 55dcce91df
43 changed files with 2436 additions and 155 deletions

View File

@@ -22,7 +22,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Type, Union
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
import torch
from torch import nn
@@ -65,6 +65,7 @@ class LlamaMLP(nn.Module):
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
prefix: str = "",
reduce_results: bool = True,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
@@ -79,6 +80,7 @@ class LlamaMLP(nn.Module):
output_size=hidden_size,
bias=bias,
quant_config=quant_config,
reduce_results=reduce_results,
prefix=f"{prefix}.down_proj",
)
if hidden_act != "silu":
@@ -292,7 +294,7 @@ class LlamaModel(nn.Module):
*,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer):
layer_type: type[nn.Module] = LlamaDecoderLayer):
super().__init__()
config = vllm_config.model_config.hf_config
@@ -466,10 +468,14 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"ffn_norm": "post_attention_layernorm",
"tok_embeddings": "model.embed_tokens",
"output": "lm_head",
"norm": "model.norm"
"norm": "model.norm",
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: type[nn.Module] = LlamaDecoderLayer):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
@@ -478,7 +484,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.lora_config = lora_config
self.model = self._init_model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
prefix=maybe_prefix(prefix, "model"),
layer_type=layer_type)
if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
@@ -513,8 +520,13 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):
return LlamaModel(vllm_config=vllm_config, prefix=prefix)
def _init_model(self,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: type[nn.Module] = LlamaDecoderLayer):
return LlamaModel(vllm_config=vllm_config,
prefix=prefix,
layer_type=layer_type)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)