[Core] Support loading GGUF model (#5191)
Co-authored-by: Michael Goin <michael@neuralmagic.com>
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import inspect
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -23,6 +24,14 @@ class QuantizeMethodBase(ABC):
|
||||
Expects create_weights to have been called before on the layer."""
|
||||
raise NotImplementedError
|
||||
|
||||
# Not required functions
|
||||
def embedding(self, layer: torch.nn.Module, *args,
|
||||
**kwargs) -> torch.Tensor:
|
||||
"""Gather embeddings in the layer based on indices in the input tensor.
|
||||
|
||||
Expects create_weights to have been called before on the layer."""
|
||||
raise NotImplementedError
|
||||
|
||||
def process_weights_after_loading(self, layer: nn.Module) -> None:
|
||||
"""Process the weight after loading.
|
||||
|
||||
@@ -31,6 +40,21 @@ class QuantizeMethodBase(ABC):
|
||||
return
|
||||
|
||||
|
||||
def method_has_implemented_embedding(
|
||||
method_class: Type[QuantizeMethodBase]) -> bool:
|
||||
"""
|
||||
Not all quant methods have embedding implemented, so we need to check that
|
||||
it exists for our given method. We check this by making sure the function
|
||||
has been changed from the base implementation.
|
||||
"""
|
||||
base_embedding = inspect.getattr_static(QuantizeMethodBase, "embedding",
|
||||
None)
|
||||
class_embedding = inspect.getattr_static(method_class, "embedding", None)
|
||||
|
||||
return (class_embedding is not None
|
||||
and class_embedding is not base_embedding)
|
||||
|
||||
|
||||
class QuantizationConfig(ABC):
|
||||
"""Base class for quantization configs."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user