[Models] Add remaining model PP support (#7168)

Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Signed-off-by: Murali Andoorveedu <muralidhar.andoorveedu@centml.ai>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
commit 0f6d7a9a34
parent 303d44790a
Author: Murali Andoorveedu
Date:   2024-10-03 19:56:58 -07:00
Committed by: GitHub

69 changed files with 2585 additions and 1344 deletions

--- a/vllm/model_executor/models/utils.py
+++ b/vllm/model_executor/models/utils.py

@@ -24,7 +24,7 @@ class WeightsGroup(UserDict):
     when attempting to access a weight component that does not exist.
     """
 
-    def __getitem__(self, key: str) -> int:
+    def __getitem__(self, key: str) -> Iterable[Tuple[str, torch.Tensor]]:
         try:
             return super().__getitem__(key)
         except KeyError as exc:
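
The previous return annotation was simply wrong: the grouped values are iterables of (name, tensor) pairs, not ints. A minimal sketch of the behavior the docstring describes, assuming these helpers live in vllm/model_executor/models/utils.py (the file shown here):

```python
# Sketch only, not part of the PR: a missing component raises the
# more informative KeyError that WeightsGroup.__getitem__ provides.
import torch

from vllm.model_executor.models.utils import group_weights_with_prefix

weights = [("language_model.lm_head.weight", torch.empty(8, 4))]
grouped = group_weights_with_prefix(weights)

grouped["vision_tower"]  # raises KeyError with an informative message
```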
@@ -49,8 +49,7 @@ def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]],
 
 
 def group_weights_with_prefix(
-    weights: Iterable[Tuple[str, torch.Tensor]]
-) -> Dict[str, Iterable[Tuple[str, torch.Tensor]]]:
+        weights: Iterable[Tuple[str, torch.Tensor]], ) -> WeightsGroup:
     """
     Helper function to group weights with prefix
     """
@@ -183,10 +182,7 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor,
 
 
 class LayerFn(Protocol):
 
-    def __call__(
-        self,
-        prefix="",
-    ) -> torch.nn.Module:
+    def __call__(self, prefix: str) -> torch.nn.Module:
         ...
 
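
Tightening LayerFn to a required str parameter makes call sites pass the fully qualified prefix instead of relying on a default of "". A conforming factory is easy to sketch; the Linear body is a placeholder, not vLLM code:

```python
from torch import nn

def layer_fn(prefix: str) -> nn.Module:
    # A real factory would thread `prefix` (e.g. "model.layers.3") into
    # its sublayers so parameter names line up across pipeline ranks.
    return nn.Linear(16, 16)

layer = layer_fn("model.layers.0")
```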
@@ -319,8 +315,10 @@ def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
 def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int):
 
     def make_empty_intermediate_tensors(
-            batch_size: int, dtype: torch.dtype,
-            device: torch.device) -> IntermediateTensors:
+        batch_size: int,
+        dtype: torch.dtype,
+        device: torch.device,
+    ) -> IntermediateTensors:
         return IntermediateTensors({
             key: torch.zeros((batch_size, hidden_size),
                              dtype=dtype,
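
Reformatted but behaviorally unchanged: the factory closes over keys and hidden_size, so each pipeline-parallel rank can allocate zero-filled placeholder tensors of the right shape on demand. A usage sketch; the key names follow common vLLM convention but are illustrative here:

```python
import torch

from vllm.model_executor.models.utils import (
    make_empty_intermediate_tensors_factory)

make_empty = make_empty_intermediate_tensors_factory(
    keys=["hidden_states", "residual"], hidden_size=4096)
tensors = make_empty(batch_size=2, dtype=torch.float16,
                     device=torch.device("cpu"))
# Each key maps to a zeros tensor of shape (batch_size, hidden_size).
```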
@@ -342,8 +340,14 @@ class LLMWrapper(nn.Module):
         self.model_name = name
         setattr(self, name, llm)
 
-    def forward(self, *args, **kwargs) -> Any:
-        return getattr(self, self.model_name)(*args, **kwargs)
+    def __getattr__(self, key: str):
+        llm = super().__getattr__(self.model_name)
+        if key == self.model_name:
+            return llm
 
-    def embed_tokens(self, *args, **kwargs) -> Any:
-        return getattr(self, self.model_name).embed_tokens(*args, **kwargs)
+        return getattr(llm, key)
+
+    # We need to explicitly override this
+    def __call__(self, *args: Any, **kwargs: Any) -> Any:
+        llm = super().__getattr__(self.model_name)
+        return llm(*args, **kwargs)
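
Rather than shimming forward and embed_tokens one at a time, the wrapper now delegates generically: __getattr__ forwards any unknown attribute to the wrapped module, and __call__ must be overridden explicitly because Python resolves dunder methods on the class, not the instance. A self-contained sketch with a toy inner model (not from the PR):

```python
import torch
from torch import nn

from vllm.model_executor.models.utils import LLMWrapper

class TinyLM(nn.Module):  # stand-in for a real language model
    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(10, 4)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

wrapped = LLMWrapper(TinyLM(), name="language_model")
emb_layer = wrapped.embed_tokens       # delegated via __getattr__
out = wrapped(torch.tensor([[1, 2]]))  # __call__ dispatches to TinyLM.forward
```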