[ Misc ] non-uniform quantization via compressed-tensors for Llama (#6515)
@@ -1,4 +1,4 @@
-from typing import Callable, Dict, List, Tuple
+from typing import Dict, List, Protocol, Tuple
 
 import torch
 from torch.func import functional_call
@@ -45,6 +45,15 @@ def merge_vision_embeddings(input_ids: torch.Tensor,
     return inputs_embeds
 
 
+class LayerFn(Protocol):
+
+    def __call__(
+        self,
+        prefix="",
+    ) -> torch.nn.Module:
+        ...
+
+
 class PPMissingLayer(torch.nn.Identity):
     """
     A placeholder layer for missing layers in a pipeline parallel model.
@@ -119,7 +128,9 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
 
 
 def make_layers(
-    num_hidden_layers: int, layer_fn: Callable[[], torch.nn.Module]
+    num_hidden_layers: int,
+    layer_fn: LayerFn,
+    prefix: str,
 ) -> Tuple[int, int, torch.nn.ModuleList]:
     """Make a list of layers with the given layer function, taking
     pipeline parallelism into account.
@@ -131,8 +142,8 @@ def make_layers(
                              get_pp_group().world_size)
     modules = torch.nn.ModuleList(
         [PPMissingLayer() for _ in range(start_layer)] + [
-            maybe_offload_to_cpu(layer_fn())
-            for _ in range(start_layer, end_layer)
+            maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
+            for idx in range(start_layer, end_layer)
         ] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
     return start_layer, end_layer, modules
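
Not part of the diff: a minimal, self-contained sketch of how the new prefix-aware LayerFn protocol is intended to be used. DecoderLayer below is a hypothetical stand-in for a real model layer, and the loop mirrors what make_layers now does for the layer range a pipeline-parallel rank owns (the real function also needs an initialized PP group via get_pp_group, which is omitted here).

import torch


class DecoderLayer(torch.nn.Module):
    """Hypothetical stand-in for a decoder layer built by a LayerFn."""

    def __init__(self, prefix: str = ""):
        super().__init__()
        # The fully qualified layer name (e.g. "model.layers.7") is what
        # lets a compressed-tensors config pick a per-layer (non-uniform)
        # quantization scheme.
        self.prefix = prefix


def layer_fn(prefix: str = "") -> torch.nn.Module:
    # Conforms to LayerFn: accepts a prefix, returns a torch.nn.Module.
    return DecoderLayer(prefix=prefix)


# Mirrors the new comprehension inside make_layers for the owned range:
prefix = "model.layers"
start_layer, end_layer = 0, 4  # stand-ins for the get_pp_indices() output
layers = torch.nn.ModuleList(
    layer_fn(prefix=f"{prefix}.{idx}")
    for idx in range(start_layer, end_layer))
print([layer.prefix for layer in layers])
# ['model.layers.0', 'model.layers.1', 'model.layers.2', 'model.layers.3']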