[Model] Officially support Emu3 with Transformers backend (#21319)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -23,18 +23,14 @@ def create_repo_dummy_weights(repo: str) -> Iterable[tuple[str, torch.Tensor]]:
|
||||
return ((name, torch.empty(0)) for name in weight_names)
|
||||
|
||||
|
||||
def create_model_dummy_weights(
|
||||
repo: str,
|
||||
model_arch: str,
|
||||
) -> Iterable[tuple[str, torch.Tensor]]:
|
||||
def create_dummy_model(repo: str, model_arch: str) -> PreTrainedModel:
|
||||
"""
|
||||
Create a dummy HF model of class `model_arch` from the repo's config, instantiated on the meta device
|
||||
"""
|
||||
model_cls: PreTrainedModel = getattr(transformers, model_arch)
|
||||
config = AutoConfig.from_pretrained(repo)
|
||||
with torch.device("meta"):
|
||||
model: PreTrainedModel = model_cls._from_config(config)
|
||||
return model.named_parameters()
|
||||
return model_cls._from_config(config)
|
||||
|
||||
|
||||
def model_architectures_for_test() -> list[str]:
|
||||
@@ -70,14 +66,21 @@ def test_hf_model_weights_mapper(model_arch: str):
|
||||
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
|
||||
|
||||
original_weights = create_repo_dummy_weights(model_id)
|
||||
hf_converted_weights = create_model_dummy_weights(model_id, model_arch)
|
||||
hf_dummy_model = create_dummy_model(model_id, model_arch)
|
||||
hf_converted_weights = hf_dummy_model.named_parameters()
|
||||
hf_converted_buffers = hf_dummy_model.named_buffers()
|
||||
mapper: WeightsMapper = model_cls.hf_to_vllm_mapper
|
||||
|
||||
mapped_original_weights = mapper.apply(original_weights)
|
||||
mapped_hf_converted_weights = mapper.apply(hf_converted_weights)
|
||||
mapped_hf_converted_buffers = mapper.apply(hf_converted_buffers)
|
||||
|
||||
ref_weight_names = set(map(lambda x: x[0], mapped_original_weights))
|
||||
weight_names = set(map(lambda x: x[0], mapped_hf_converted_weights))
|
||||
buffer_names = set(map(lambda x: x[0], mapped_hf_converted_buffers))
|
||||
|
||||
# Some checkpoints may contain buffers; we ignore them for this test
|
||||
ref_weight_names -= buffer_names
|
||||
|
||||
weights_missing = ref_weight_names - weight_names
|
||||
weights_unmapped = weight_names - ref_weight_names
|
||||
|
||||
Reference in New Issue
Block a user