[Bugfix] Ignore lm_head when loading embedding models (#10719)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2024-11-28 03:05:29 +08:00
committed by GitHub
parent 197b4484a3
commit 9b4b150395
4 changed files with 8 additions and 0 deletions

View File

@@ -443,6 +443,8 @@ class BertEmbeddingModel(nn.Module):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
weights = hf_to_vllm_mapper.apply(weights)
weights = ((name, data) for name, data in weights
if not name.startswith("lm_head."))
self.model.load_weights(weights)
def _build_model(self,

View File

@@ -504,4 +504,6 @@ class Gemma2EmbeddingModel(nn.Module, SupportsPP):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
weights = hf_to_vllm_mapper.apply(weights)
weights = ((name, data) for name, data in weights
if not name.startswith("lm_head."))
self.model.load_weights(weights)

View File

@@ -689,6 +689,8 @@ class LlamaEmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
weights = hf_to_vllm_mapper.apply(weights)
weights = ((name, data) for name, data in weights
if not name.startswith("lm_head."))
self.model.load_weights(weights)
def load_kv_cache_scales(self, quantization_param_path: str) -> None:

View File

@@ -580,4 +580,6 @@ class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
weights = hf_to_vllm_mapper.apply(weights)
weights = ((name, data) for name, data in weights
if not name.startswith("lm_head."))
self.model.load_weights(weights)