[Model] Support telechat2 (#10311)
Signed-off-by: Isotr0py <2037008807@qq.com> Co-authored-by: xiangw2 <xiangw2@chinatelecom.cn> Co-authored-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
@@ -501,8 +501,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
self.config = config
|
||||
self.lora_config = lora_config
|
||||
|
||||
self.model = LlamaModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.model = self._init_model(vllm_config=vllm_config, prefix=prefix)
|
||||
if get_pp_group().is_last_rank:
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
if lora_config:
|
||||
@@ -539,6 +538,9 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
normalize=False,
|
||||
softmax=False)
|
||||
|
||||
def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):
    """Construct the backbone model wrapped by this causal-LM class.

    Model construction lives in its own method rather than inline in the
    constructor.
    NOTE(review): presumably this lets a subclass override it and return a
    different backbone while reusing the rest of the class; confirm
    against the subclasses.

    Args:
        vllm_config: Engine-wide configuration, forwarded unchanged.
        prefix: Name prefix for the backbone, forwarded unchanged. Nothing
            is appended here.
            NOTE(review): confirm the caller passes the fully qualified
            prefix it intends the backbone to use.

    Returns:
        A newly constructed ``LlamaModel``.
    """
    return LlamaModel(vllm_config=vllm_config, prefix=prefix)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Return the input embeddings for ``input_ids``.

    This is a pure delegation to ``self.model.get_input_embeddings``; no
    processing happens at this level.

    Args:
        input_ids: Token ids to embed, passed through unchanged.

    Returns:
        Whatever the wrapped model's ``get_input_embeddings`` returns.
    """
    return self.model.get_input_embeddings(input_ids)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user