[BugFix][TritonMLA] Process weights after model loading for GGUF (#14555)
Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>
This commit is contained in:
@@ -1330,11 +1330,14 @@ class GGUFModelLoader(BaseModelLoader):
                 local_model_path, gguf_weights_map):
             model_config.hf_config.update({"tie_word_embeddings": True})
 
+        target_device = torch.device(device_config.device)
         with set_default_torch_dtype(model_config.dtype):
-            with torch.device(device_config.device):
+            with target_device:
                 model = _initialize_model(vllm_config=vllm_config)
             model.load_weights(
                 self._get_weights_iterator(local_model_path, gguf_weights_map))
+
+            _process_weights_after_loading(model, model_config, target_device)
         return model
 
 
Reference in New Issue
Block a user