[Models] Add remaining model PP support (#7168)
Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Signed-off-by: Murali Andoorveedu <muralidhar.andoorveedu@centml.ai>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
committed by GitHub
parent 303d44790a
commit 0f6d7a9a34
@@ -24,7 +24,7 @@ class WeightsGroup(UserDict):
     when attempting to access a weight component that does not exist.
     """
 
-    def __getitem__(self, key: str) -> int:
+    def __getitem__(self, key: str) -> Iterable[Tuple[str, torch.Tensor]]:
         try:
             return super().__getitem__(key)
         except KeyError as exc:
@@ -49,8 +49,7 @@ def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]],
 
 
 def group_weights_with_prefix(
-    weights: Iterable[Tuple[str, torch.Tensor]]
-) -> Dict[str, Iterable[Tuple[str, torch.Tensor]]]:
+        weights: Iterable[Tuple[str, torch.Tensor]], ) -> WeightsGroup:
     """
     Helper function to group weights with prefix
     """
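For context on these first two hunks: group_weights_with_prefix now returns a WeightsGroup, a UserDict whose values are iterables of (name, tensor) pairs and whose __getitem__ raises a descriptive KeyError when a component is missing. Below is a minimal, self-contained sketch of that pattern; the names WeightsGroupSketch and group_by_prefix_sketch, and the error wording, are illustrative assumptions rather than vLLM's exact implementation.

# Sketch of the grouping pattern: bucket (name, tensor) pairs by their
# top-level prefix and make missing-prefix lookups fail with a clear message.
from collections import UserDict
from typing import Dict, Iterable, List, Tuple

import torch


class WeightsGroupSketch(UserDict):

    def __getitem__(self, key: str) -> Iterable[Tuple[str, torch.Tensor]]:
        try:
            return super().__getitem__(key)
        except KeyError as exc:
            raise KeyError(f"No weights found for prefix: {key}. "
                           f"Available prefixes: {set(self.keys())}") from exc


def group_by_prefix_sketch(
        weights: Iterable[Tuple[str, torch.Tensor]]) -> WeightsGroupSketch:
    grouped: Dict[str, List[Tuple[str, torch.Tensor]]] = {}
    for name, tensor in weights:
        prefix, _, rest = name.partition(".")
        grouped.setdefault(prefix, []).append((rest, tensor))
    return WeightsGroupSketch(grouped)


if __name__ == "__main__":
    weights = [("language_model.lm_head.weight", torch.zeros(4)),
               ("vision_tower.patch_embed.weight", torch.zeros(4))]
    groups = group_by_prefix_sketch(weights)
    print(list(groups["language_model"]))
    # groups["missing"] would raise a KeyError naming the available prefixes.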
@@ -183,10 +182,7 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor,
 
 class LayerFn(Protocol):
 
-    def __call__(
-        self,
-        prefix="",
-    ) -> torch.nn.Module:
+    def __call__(self, prefix: str) -> torch.nn.Module:
         ...
 
 
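The tightened LayerFn protocol requires every layer factory to accept a prefix: str and return an nn.Module. Here is a small sketch of a callable that satisfies it and of how such a factory is typically consumed; make_linear_layer and build_layers are made-up names for illustration, not part of vLLM.

# Sketch: a callable conforming to LayerFn and a helper that builds a stack of
# layers, passing each layer its fully qualified prefix.
from typing import Protocol

import torch
from torch import nn


class LayerFn(Protocol):

    def __call__(self, prefix: str) -> torch.nn.Module:
        ...


def make_linear_layer(prefix: str) -> nn.Module:
    # The prefix (e.g. "model.layers.3") is normally used to name the layer's
    # weights for loading; here we just attach it for inspection.
    layer = nn.Linear(16, 16)
    layer.prefix = prefix
    return layer


def build_layers(layer_fn: LayerFn, num_layers: int) -> nn.ModuleList:
    return nn.ModuleList(
        [layer_fn(prefix=f"model.layers.{i}") for i in range(num_layers)])


layers = build_layers(make_linear_layer, num_layers=2)
print([layer.prefix for layer in layers])  # ['model.layers.0', 'model.layers.1']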
@@ -319,8 +315,10 @@ def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
 def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int):
 
     def make_empty_intermediate_tensors(
-            batch_size: int, dtype: torch.dtype,
-            device: torch.device) -> IntermediateTensors:
+        batch_size: int,
+        dtype: torch.dtype,
+        device: torch.device,
+    ) -> IntermediateTensors:
         return IntermediateTensors({
             key: torch.zeros((batch_size, hidden_size),
                              dtype=dtype,
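make_empty_intermediate_tensors_factory is a closure factory: it captures the tensor keys and hidden size once, and each call then produces zero-filled placeholders sized to the current batch, which pipeline-parallel stages use as stand-in hidden states. A self-contained sketch of the pattern follows; IntermediateTensorsSketch is a stand-in for vLLM's IntermediateTensors, used only so the example runs on its own.

# Sketch of the closure-factory pattern for empty pipeline-parallel tensors.
from typing import Dict, List

import torch


class IntermediateTensorsSketch:

    def __init__(self, tensors: Dict[str, torch.Tensor]) -> None:
        self.tensors = tensors


def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int):

    def make_empty_intermediate_tensors(
        batch_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> IntermediateTensorsSketch:
        # One zero tensor per key, all shaped (batch_size, hidden_size).
        return IntermediateTensorsSketch({
            key: torch.zeros((batch_size, hidden_size),
                             dtype=dtype,
                             device=device)
            for key in keys
        })

    return make_empty_intermediate_tensors


factory = make_empty_intermediate_tensors_factory(
    ["hidden_states", "residual"], hidden_size=4096)
empty = factory(batch_size=2, dtype=torch.float16, device=torch.device("cpu"))
print(empty.tensors["hidden_states"].shape)  # torch.Size([2, 4096])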
@@ -342,8 +340,14 @@ class LLMWrapper(nn.Module):
         self.model_name = name
         setattr(self, name, llm)
 
-    def forward(self, *args, **kwargs) -> Any:
-        return getattr(self, self.model_name)(*args, **kwargs)
+    def __getattr__(self, key: str):
+        llm = super().__getattr__(self.model_name)
+        if key == self.model_name:
+            return llm
 
-    def embed_tokens(self, *args, **kwargs) -> Any:
-        return getattr(self, self.model_name).embed_tokens(*args, **kwargs)
+        return getattr(llm, key)
+
+    # We need to explicitly override this
+    def __call__(self, *args: Any, **kwargs: Any) -> Any:
+        llm = super().__getattr__(self.model_name)
+        return llm(*args, **kwargs)
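The last hunk replaces LLMWrapper's hand-written forward/embed_tokens pass-throughs with generic attribute delegation: __getattr__ forwards any attribute lookup to the wrapped module, while __call__ must still be overridden explicitly because Python resolves dunder methods on the type and therefore bypasses instance-level __getattr__. A standalone sketch of the pattern (WrapperSketch is an illustrative name, not vLLM's class):

# Sketch of the delegation pattern: wrap a module under a chosen name and
# forward every other attribute access to it.
from typing import Any

import torch
from torch import nn


class WrapperSketch(nn.Module):

    def __init__(self, llm: nn.Module, name: str) -> None:
        super().__init__()
        self.model_name = name
        setattr(self, name, llm)

    def __getattr__(self, key: str):
        # nn.Module keeps submodules outside __dict__, so fetch the wrapped
        # module via the parent's __getattr__ before delegating to it.
        llm = super().__getattr__(self.model_name)
        if key == self.model_name:
            return llm
        return getattr(llm, key)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        llm = super().__getattr__(self.model_name)
        return llm(*args, **kwargs)


wrapped = WrapperSketch(nn.Linear(4, 4), name="model")
print(wrapped.weight.shape)        # delegated to the wrapped Linear's weight
print(wrapped(torch.zeros(1, 4)))  # __call__ runs the wrapped module directly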