[Misc] Move weights mapper (#11443)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
@@ -529,6 +529,8 @@ class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
|
||||
embedding_modules = {}
|
||||
embedding_padding_modules = []
|
||||
|
||||
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
@@ -577,8 +579,7 @@ class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
|
||||
return self._pooler(hidden_states, pooling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
|
||||
weights = hf_to_vllm_mapper.apply(weights)
|
||||
weights = self.hf_to_vllm_mapper.apply(weights)
|
||||
weights = ((name, data) for name, data in weights
|
||||
if not name.startswith("lm_head."))
|
||||
self.model.load_weights(weights)
|
||||
|
||||
Reference in New Issue
Block a user