[ Misc ] non-uniform quantization via compressed-tensors for Llama (#6515)

Robert Shaw
2024-07-18 22:39:18 -04:00
committed by GitHub
parent d4201e06d5
commit dbe5588554
11 changed files with 301 additions and 91 deletions
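Non-uniform quantization here means the compressed-tensors config can map different layer names to different quantization schemes (or exclude some layers entirely), so every layer must know its fully-qualified module name when it is built. A minimal sketch of that lookup, assuming a hypothetical resolve_scheme helper and a simplified scheme table rather than the real compressed-tensors config format:

from typing import Dict, Optional

# Hypothetical, simplified per-layer scheme table keyed by substrings of a
# layer's fully-qualified name; the real compressed-tensors config is richer.
LAYER_SCHEMES: Dict[str, Optional[str]] = {
    "lm_head": None,           # leave this module unquantized
    "layers.0.": "w8a8-int8",  # first decoder layer: int8 weights/activations
    "": "w4a16",               # default scheme for every other layer
}


def resolve_scheme(prefix: str) -> Optional[str]:
    """Pick a quantization scheme from a layer's dotted module name."""
    for pattern, scheme in LAYER_SCHEMES.items():
        if pattern in prefix:
            return scheme
    return None


# resolve_scheme("model.layers.0.mlp.down_proj") -> "w8a8-int8"
# resolve_scheme("model.layers.7.self_attn.qkv_proj") -> "w4a16"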


@@ -1,4 +1,4 @@
-from typing import Callable, Dict, List, Tuple
+from typing import Dict, List, Protocol, Tuple

 import torch
 from torch.func import functional_call
@@ -45,6 +45,15 @@ def merge_vision_embeddings(input_ids: torch.Tensor,
     return inputs_embeds


+class LayerFn(Protocol):
+
+    def __call__(
+        self,
+        prefix="",
+    ) -> torch.nn.Module:
+        ...
+
+
 class PPMissingLayer(torch.nn.Identity):
     """
     A placeholder layer for missing layers in a pipeline parallel model.
@@ -119,7 +128,9 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:


 def make_layers(
-    num_hidden_layers: int, layer_fn: Callable[[], torch.nn.Module]
+    num_hidden_layers: int,
+    layer_fn: LayerFn,
+    prefix: str,
 ) -> Tuple[int, int, torch.nn.ModuleList]:
     """Make a list of layers with the given layer function, taking
     pipeline parallelism into account.
@@ -131,8 +142,8 @@ def make_layers(
                                             get_pp_group().world_size)
     modules = torch.nn.ModuleList(
         [PPMissingLayer() for _ in range(start_layer)] + [
-            maybe_offload_to_cpu(layer_fn())
-            for _ in range(start_layer, end_layer)
+            maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
+            for idx in range(start_layer, end_layer)
         ] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
     return start_layer, end_layer, modules
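With the new signature, a model's __init__ passes make_layers a prefix and a layer factory that forwards it. The sketch below shows that caller pattern for a Llama-style model; the decoder-layer constructor arguments are assumed for illustration rather than taken from this diff:

# Sketch of a caller inside a Llama-style model __init__; only the prefix
# plumbing is guaranteed by this diff, the other arguments are assumptions.
self.start_layer, self.end_layer, self.layers = make_layers(
    config.num_hidden_layers,
    lambda prefix: LlamaDecoderLayer(config=config,
                                     quant_config=quant_config,
                                     prefix=prefix),
    prefix=f"{prefix}.layers",
)
# Each decoder layer is then built with a name like "model.layers.3", which a
# quantization method can use to pick that layer's scheme (or skip it).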