[Model] Support telechat2 (#10311)

Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: xiangw2 <xiangw2@chinatelecom.cn>
Co-authored-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
shunxing12345
2024-11-27 19:32:35 +08:00
committed by GitHub
parent e2251109c7
commit 1209261e93
8 changed files with 210 additions and 3 deletions

View File

@@ -501,8 +501,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.config = config
self.lora_config = lora_config
self.model = LlamaModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.model = self._init_model(vllm_config=vllm_config, prefix=prefix)
if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
if lora_config:
@@ -539,6 +538,9 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
normalize=False,
softmax=False)
def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):
return LlamaModel(vllm_config=vllm_config, prefix=prefix)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)