Remove LoRA bias support (#25807)
Signed-off-by: Ashwin Phadke <ashwinphadke12@rediffmail.com>
Signed-off-by: Ashwin Phadke <23502062+ashwin-phadke@users.noreply.github.com>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
@@ -23,11 +23,6 @@ BADREQUEST_CASES = [
         {"r": 1024},
         "is greater than max_lora_rank",
     ),
-    (
-        "test_bias",
-        {"bias": "all"},
-        "Adapter bias cannot be used without bias_enabled",
-    ),
     ("test_dora", {"use_dora": True}, "does not yet support DoRA"),
     (
         "test_modules_to_save",
@@ -16,11 +16,6 @@ ERROR_CASES = [
         {"r": 1024},
         "is greater than max_lora_rank",
     ),
-    (
-        "test_bias",
-        {"bias": "all"},
-        "Adapter bias cannot be used without bias_enabled",
-    ),
     ("test_dora", {"use_dora": True}, "does not yet support DoRA"),
     (
         "test_modules_to_save",
@@ -21,7 +21,6 @@ class LoRANameParserTestConfig(NamedTuple):
     name: str
     module_name: str
     is_lora_a: bool
-    is_bias: bool
     weights_mapper: Optional[WeightsMapper] = None
 
 
@@ -37,44 +36,37 @@ def test_parse_fine_tuned_lora_name_valid():
             "base_model.model.model.embed_tokens.lora_embedding_A",
             "model.embed_tokens",
             True,
-            False,
         ),
         LoRANameParserTestConfig(
             "base_model.model.model.embed_tokens.lora_embedding_B",
             "model.embed_tokens",
             False,
-            False,
         ),
         LoRANameParserTestConfig(
             "base_model.model.model.layers.9.mlp.down_proj.lora_A.weight",
             "model.layers.9.mlp.down_proj",
             True,
-            False,
         ),
         LoRANameParserTestConfig(
             "base_model.model.model.layers.9.mlp.down_proj.lora_B.weight",
             "model.layers.9.mlp.down_proj",
             False,
-            False,
         ),
         LoRANameParserTestConfig(
             "language_model.layers.9.mlp.down_proj.lora_A.weight",
             "language_model.layers.9.mlp.down_proj",
             True,
-            False,
         ),
         LoRANameParserTestConfig(
             "language_model.layers.9.mlp.down_proj.lora_B.weight",
             "language_model.layers.9.mlp.down_proj",
             False,
-            False,
         ),
         # Test with WeightsMapper
         LoRANameParserTestConfig(
             "base_model.model.model.layers.9.mlp.down_proj.lora_A.weight",
             "language_model.model.layers.9.mlp.down_proj",
             True,
-            False,
             weights_mapper=WeightsMapper(
                 orig_to_new_prefix={"model.": "language_model.model."}
             ),
@@ -83,7 +75,6 @@ def test_parse_fine_tuned_lora_name_valid():
             "base_model.model.model.layers.9.mlp.down_proj.lora_B.weight",
             "language_model.model.layers.9.mlp.down_proj",
             False,
-            False,
             weights_mapper=WeightsMapper(
                 orig_to_new_prefix={"model.": "language_model.model."}
             ),
@@ -92,7 +83,6 @@ def test_parse_fine_tuned_lora_name_valid():
             "model.layers.9.mlp.down_proj.lora_A.weight",
             "language_model.model.layers.9.mlp.down_proj",
             True,
-            False,
             weights_mapper=WeightsMapper(
                 orig_to_new_prefix={"model.": "language_model.model."}
             ),
@@ -101,14 +91,13 @@ def test_parse_fine_tuned_lora_name_valid():
             "model.layers.9.mlp.down_proj.lora_B.weight",
             "language_model.model.layers.9.mlp.down_proj",
             False,
-            False,
             weights_mapper=WeightsMapper(
                 orig_to_new_prefix={"model.": "language_model.model."}
             ),
         ),
     ]
-    for name, module_name, is_lora_a, is_bias, weights_mapper in fixture:
-        assert (module_name, is_lora_a, is_bias) == parse_fine_tuned_lora_name(
+    for name, module_name, is_lora_a, weights_mapper in fixture:
+        assert (module_name, is_lora_a) == parse_fine_tuned_lora_name(
             name, weights_mapper
         )
 
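
Note: parse_fine_tuned_lora_name now returns a (module_name, is_lora_a) pair; the
is_bias element of the old 3-tuple is gone. A minimal sketch of the new contract,
using a name and expected result taken from the fixture above:

    module_name, is_lora_a = parse_fine_tuned_lora_name(
        "base_model.model.model.embed_tokens.lora_embedding_A", None
    )
    assert (module_name, is_lora_a) == ("model.embed_tokens", True)
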
@@ -70,12 +70,6 @@ class LoRAConfig:
     per prompt. When run in offline mode, the lora IDs for n modalities
     will be automatically assigned to 1-n with the names of the modalities
     in alphabetic order."""
-    bias_enabled: bool = Field(
-        default=False,
-        deprecated="`bias_enabled` is deprecated and will be removed in v0.12.0.",
-    )
-    """[DEPRECATED] Enable bias for LoRA adapters. This option will be
-    removed in v0.12.0."""
 
     def compute_hash(self) -> str:
         """
@@ -96,7 +90,6 @@ class LoRAConfig:
         factors.append(self.lora_dtype)
         factors.append(self.lora_extra_vocab_size)
         factors.append(self.lora_vocab_padding_size)
-        factors.append(self.bias_enabled)
         hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
         return hash_str
 
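
Note: bias_enabled is removed both as a LoRAConfig field and as a compute_hash
factor. A hypothetical usage sketch (field names per the hunks above):

    from vllm.config import LoRAConfig

    lora_config = LoRAConfig(max_lora_rank=16, max_loras=2)  # still valid
    # LoRAConfig(bias_enabled=True)  # no longer a field; construction would fail
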
@@ -439,7 +439,6 @@ class EngineArgs:
     video_pruning_rate: float = MultiModalConfig.video_pruning_rate
     # LoRA fields
     enable_lora: bool = False
-    enable_lora_bias: bool = LoRAConfig.bias_enabled
     max_loras: int = LoRAConfig.max_loras
     max_lora_rank: int = LoRAConfig.max_lora_rank
     default_mm_loras: Optional[dict[str, str]] = LoRAConfig.default_mm_loras
@@ -916,7 +915,6 @@ class EngineArgs:
             action=argparse.BooleanOptionalAction,
             help="If True, enable handling of LoRA adapters.",
         )
-        lora_group.add_argument("--enable-lora-bias", **lora_kwargs["bias_enabled"])
         lora_group.add_argument("--max-loras", **lora_kwargs["max_loras"])
         lora_group.add_argument("--max-lora-rank", **lora_kwargs["max_lora_rank"])
         lora_group.add_argument(
@@ -1515,7 +1513,6 @@ class EngineArgs:
 
         lora_config = (
             LoRAConfig(
-                bias_enabled=self.enable_lora_bias,
                 max_lora_rank=self.max_lora_rank,
                 max_loras=self.max_loras,
                 default_mm_loras=self.default_mm_loras,
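
Note: the enable_lora_bias engine arg and its --enable-lora-bias CLI flag are
gone; enabling LoRA itself is unchanged. A sketch under those assumptions:

    from vllm.engine.arg_utils import EngineArgs

    args = EngineArgs(enable_lora=True, max_loras=2, max_lora_rank=16)
    # EngineArgs(enable_lora_bias=True)  # removed; likewise `--enable-lora-bias`
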
@@ -45,7 +45,6 @@ class BaseLayerWithLoRA(nn.Module):
         lora_a: torch.Tensor,
         lora_b: torch.Tensor,
         embeddings_tensor: Optional[torch.Tensor],
-        bias: Optional[torch.Tensor] = None,
     ):
         """Overwrites lora tensors at index."""
         ...
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from typing import Optional, cast
+from typing import Optional
 
 import torch
 from transformers import PretrainedConfig
@@ -29,7 +29,6 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
         self.tp_size = self.base_layer.tp_size
         self.tp_rank = self.base_layer.tp_rank
         self.device = _get_lora_device(self.base_layer)
-        self.lora_bias_stacked: Optional[tuple[torch.Tensor, ...]] = None
         self.output_slices: tuple[int, ...]
         self.output_size: int
         self.n_slices: int
@@ -86,30 +85,12 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
             )
             for _ in range(self.n_slices)
         )
-        if lora_config.bias_enabled:
-            lora_bias_out_size = lora_b_out_size
-            self.lora_bias_stacked = tuple(
-                torch.zeros(
-                    max_loras,
-                    1,
-                    lora_bias_out_size,
-                    dtype=lora_config.lora_dtype,
-                    device=self.device,
-                )
-                for _ in range(self.n_slices)
-            )
         self.output_slices = (self.lora_b_stacked[0].shape[2],)
 
     def reset_lora(self, index: int):
         for s_index in range(self.n_slices):
             self.lora_a_stacked[s_index][index] = 0
             self.lora_b_stacked[s_index][index] = 0
-            if self.lora_config.bias_enabled:
-                # Make mypy happy
-                self.lora_bias_stacked = cast(
-                    tuple[torch.Tensor, ...], self.lora_bias_stacked
-                )
-                self.lora_bias_stacked[s_index][index] = 0
 
     def set_lora(
         self,
@@ -117,7 +98,6 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
         lora_a: torch.Tensor,
         lora_b: torch.Tensor,
         embeddings_tensor: Optional[torch.Tensor],
-        lora_bias: Optional[torch.Tensor] = None,
     ):
         # Except for QKVParallelLinearWithLoRA and
         # MergedColumnParallelLinearWithLoRA, all other linear LoRA layers
@@ -131,8 +111,6 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
         if self.tp_size > 1:
             lora_a = self.slice_lora_a(lora_a)
             lora_b = self.slice_lora_b(lora_b)
-            if lora_bias is not None:
-                lora_bias = self.slice_bias(lora_bias)
 
         self.lora_a_stacked[0][index, 0, : lora_a.shape[0], : lora_a.shape[1]].copy_(
             lora_a, non_blocking=True
@@ -140,14 +118,6 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
         self.lora_b_stacked[0][index, 0, : lora_b.shape[0], : lora_b.shape[1]].copy_(
             lora_b, non_blocking=True
         )
-        if lora_bias is not None:
-            self.lora_bias_stacked = cast(
-                tuple[torch.Tensor, ...], self.lora_bias_stacked
-            )
-            assert len(self.lora_bias_stacked)
-            self.lora_bias_stacked[0][index, 0, : lora_bias.shape[0]].copy_(
-                lora_bias, non_blocking=True
-            )
 
     def apply(
         self, x: torch.Tensor, bias: Optional[torch.Tensor] = None
@@ -162,13 +132,7 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
         x = x.flatten(0, 1)
 
         lora_output: Optional[torch.Tensor] = self.punica_wrapper.add_lora_linear(
-            output,
-            x,
-            self.lora_a_stacked,
-            self.lora_b_stacked,
-            self.lora_bias_stacked,
-            1.0,
-            self.output_slices,
+            output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0, self.output_slices
         )
         if not current_platform.can_update_inplace():
             output = lora_output
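
Note: add_lora_linear loses its positional lora_bias_stacked argument, so every
call site shifts by one. A before/after sketch of the call in apply() above:

    # before
    self.punica_wrapper.add_lora_linear(
        output, x, self.lora_a_stacked, self.lora_b_stacked,
        self.lora_bias_stacked, 1.0, self.output_slices,
    )
    # after
    self.punica_wrapper.add_lora_linear(
        output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0, self.output_slices
    )
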
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from typing import Optional, Union, cast
+from typing import Optional, Union
 
 import torch
 import torch.nn as nn
@@ -32,8 +32,6 @@ def _mcp_apply(x, bias, layer: "ColumnParallelLinearWithLoRA"):
         == len(layer.lora_b_stacked)
         == len(layer.output_slices)
     )
-    if layer.lora_bias_stacked is not None:
-        assert layer.n_slices == len(layer.lora_bias_stacked)
 
     output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias)
 
@@ -61,7 +59,6 @@ def _mcp_apply(x, bias, layer: "ColumnParallelLinearWithLoRA"):
         output,
         buffers,
         layer.lora_b_stacked,
-        layer.lora_bias_stacked,
         layer.output_slices,
         offset_start=0,
         add_input=True,
@@ -122,16 +119,6 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
         lora_b = lora_b[start_idx:end_idx, :]
         return lora_b
 
-    def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
-        # TODO: Fix the slicing logic of bias.
-        if bias is None:
-            return bias
-        shard_size = self.output_size
-        start_idx = self.tp_rank * shard_size
-        end_idx = (self.tp_rank + 1) * shard_size
-        bias = bias[start_idx:end_idx]
-        return bias
-
     def forward(
         self, input_: torch.Tensor
     ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]:
@@ -238,17 +225,6 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
             )
             for output_size in self.output_slices
         )
-        if lora_config.bias_enabled:
-            self.lora_bias_stacked = tuple(
-                torch.zeros(
-                    max_loras,
-                    1,
-                    output_size,
-                    dtype=lora_config.lora_dtype,
-                    device=self.device,
-                )
-                for output_size in self.output_slices
-            )
 
     def slice_lora_a(
         self, lora_a: list[Union[torch.Tensor, None]]
@@ -268,31 +244,18 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         ]
         return sliced_lora_b
 
-    def slice_bias(
-        self, bias: list[Union[torch.Tensor, None]]
-    ) -> list[Union[torch.Tensor, None]]:
-        for i, (shard_id, shard_size) in enumerate(
-            zip(self.output_ids, self.output_slices)
-        ):
-            if (bias_i := bias[i]) is not None:
-                bias[i] = bias_i[shard_size * shard_id : shard_size * (shard_id + 1)]
-        return bias
-
     def set_lora(
         self,
         index: int,
         lora_a: torch.Tensor,
         lora_b: torch.Tensor,
         embeddings_tensor: Optional[torch.Tensor],
-        lora_bias: Optional[torch.Tensor] = None,
     ):
         self.reset_lora(index)
 
         if self.tp_size > 1:
             lora_a = self.slice_lora_a(lora_a)
             lora_b = self.slice_lora_b(lora_b)
-            if lora_bias is not None:
-                lora_bias = self.slice_bias(lora_bias)
 
         for i in range(self.n_slices):
             if (lora_a_i := lora_a[i]) is not None:
@@ -304,16 +267,6 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
                     index, 0, : lora_b_i.shape[0], : lora_b_i.shape[1]
                 ].copy_(lora_b_i, non_blocking=True)
 
-        if lora_bias is not None:
-            self.lora_bias_stacked = cast(
-                tuple[torch.Tensor, ...], self.lora_bias_stacked
-            )
-            for i in range(self.n_slices):
-                if (lora_bias_i := lora_bias[i]) is not None:
-                    self.lora_bias_stacked[i][index, 0, : lora_bias_i.shape[0]].copy_(
-                        lora_bias_i, non_blocking=True
-                    )
-
     @classmethod
     @_not_fully_sharded_can_replace
     def can_replace_layer(
@@ -380,24 +333,6 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=0)
         return lora_b
 
-    def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
-        bias_q = bias[
-            self.q_proj_shard_size * self.q_shard_id : self.q_proj_shard_size
-            * (self.q_shard_id + 1)
-        ]
-        k_offset = self.q_proj_total_size
-        bias_k = bias[
-            k_offset + self.kv_proj_shard_size * self.kv_shard_id : k_offset
-            + self.kv_proj_shard_size * (self.kv_shard_id + 1)
-        ]
-        v_offset = k_offset + self.kv_proj_total_size
-        bias_v = bias[
-            v_offset + self.kv_proj_shard_size * self.kv_shard_id : v_offset
-            + self.kv_proj_shard_size * (self.kv_shard_id + 1)
-        ]
-        bias = torch.cat([bias_q, bias_k, bias_v], dim=1)
-        return bias
-
     @classmethod
     @_not_fully_sharded_can_replace
     def can_replace_layer(
@@ -143,7 +143,6 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
         lora_a: torch.Tensor,
         lora_b: torch.Tensor,
         embeddings_tensor: Optional[torch.Tensor],
-        bias: Optional[torch.Tensor] = None,
     ):
         self.reset_lora(index)
         self.lora_a_stacked[index, 0, : lora_a.shape[0], : lora_a.shape[1]].copy_(
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from typing import Optional, Union, cast
+from typing import Optional, Union
 
 import torch
 import torch.nn as nn
@@ -39,9 +39,6 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
     def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
         return lora_b
 
-    def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
-        return bias
-
     def forward(
         self, input_: torch.Tensor
     ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]:
@@ -123,16 +120,6 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
         lora_b = lora_b[start_idx:end_idx, :]
         return lora_b
 
-    def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
-        if bias is None:
-            return bias
-        self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], self.lora_bias_stacked)
-        shard_size = self.lora_bias_stacked[0].shape[2]
-        start_idx = self.tp_rank * shard_size
-        end_idx = (self.tp_rank + 1) * shard_size
-        bias = bias[start_idx:end_idx]
-        return bias
-
     def apply(
         self, x: torch.Tensor, bias: Optional[torch.Tensor] = None
     ) -> torch.Tensor:
@@ -167,7 +154,6 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
             output,
             buffer,
             self.lora_b_stacked,
-            self.lora_bias_stacked,
             self.output_slices,
             offset_start=offset_start,
             add_input=True,
@@ -91,7 +91,6 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
         lora_a: torch.Tensor,
         lora_b: torch.Tensor,
         embeddings_tensor: Optional[torch.Tensor],
-        bias: Optional[torch.Tensor] = None,
     ):
         self.reset_lora(index)
         # NOTE self.lora_a_stacked is row-major, and lora_a is col-major,
@@ -21,7 +21,6 @@ class LoRALayerWeights:
         lora_alpha: int,
         lora_a: torch.Tensor,
         lora_b: torch.Tensor,
-        bias: Optional[torch.Tensor] = None,
         embeddings_tensor: Optional[torch.Tensor] = None,
         scaling: Optional[float] = None,
     ) -> None:
@@ -30,7 +29,6 @@ class LoRALayerWeights:
         self.lora_alpha = lora_alpha
         self.lora_a = lora_a
         self.lora_b = lora_b
-        self.bias = bias
         self.embeddings_tensor = embeddings_tensor
 
         if scaling is None:
@@ -71,13 +69,13 @@ class LoRALayerWeights:
         peft_helper: PEFTHelper,
         embeddings_tensor: Optional[torch.Tensor] = None,
     ) -> "LoRALayerWeights":
+        # lora_a and lora_b are set to None for config-based construction
         return cls(
             module_name,
             peft_helper.r,
             peft_helper.lora_alpha,
             None,
             None,
-            None,
             embeddings_tensor,
             peft_helper.vllm_lora_scaling_factor,
         )
@@ -92,7 +90,6 @@ class LoRALayerWeights:
         dtype: torch.dtype,
         device: torch.types.Device,
         embeddings_tensor_dim: Optional[int] = None,
-        bias_enabled: Optional[bool] = False,
     ) -> "LoRALayerWeights":
         pin_memory = str(device) == "cpu" and is_pin_memory_available()
         lora_a = torch.zeros(
@@ -101,12 +98,6 @@ class LoRALayerWeights:
         lora_b = torch.zeros(
             [output_dim, rank], dtype=dtype, device=device, pin_memory=pin_memory
         )
-        if bias_enabled:
-            bias = torch.zeros(
-                [output_dim], dtype=dtype, device=device, pin_memory=pin_memory
-            )
-        else:
-            bias = None
 
         embeddings_tensor = (
             torch.rand(
@@ -125,7 +116,6 @@ class LoRALayerWeights:
             lora_alpha=1,
             lora_a=lora_a,
             lora_b=lora_b,
-            bias=bias,
             embeddings_tensor=embeddings_tensor,
         )
 
@@ -140,7 +130,6 @@ class PackedLoRALayerWeights(LoRALayerWeights):
         lora_alphas: list[Optional[int]],
         lora_a: list[Optional[torch.Tensor]],
         lora_b: list[Optional[torch.Tensor]],
-        bias: Optional[list[Optional[torch.Tensor]]] = None,
         scaling: Optional[list[float]] = None,
     ) -> None:
         super().__init__(
@@ -149,7 +138,6 @@ class PackedLoRALayerWeights(LoRALayerWeights):
             lora_alpha=0,
             lora_a=lora_a,
             lora_b=lora_b,
-            bias=bias,
             scaling=scaling,  # type: ignore
             embeddings_tensor=None,
         )
@@ -181,7 +169,6 @@ class PackedLoRALayerWeights(LoRALayerWeights):
             [lora.lora_alpha if lora is not None else None for lora in loras],
             [lora.lora_a if lora is not None else None for lora in loras],
             [lora.lora_b if lora is not None else None for lora in loras],
-            [lora.bias if lora is not None else None for lora in loras],
             scaling=[
                 1 if lora is not None else None  # type: ignore
                 for lora in loras
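
Note: LoRALayerWeights and PackedLoRALayerWeights drop the bias parameter and the
self.bias attribute. A sketch of the trimmed constructor; the positional order
follows the from_config call above, and the module name and tensors here are
placeholders:

    weights = LoRALayerWeights(
        "model.layers.0.mlp.down_proj",  # hypothetical module name
        8,                               # rank
        16,                              # lora_alpha
        lora_a,
        lora_b,                          # [output_dim, rank], per create_dummy_lora_weights
    )
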
@@ -3,7 +3,6 @@
 
 import math
 import os
-from collections.abc import Sequence
 from typing import Callable, Optional, TypeVar, Union
 
 import regex as re
@@ -140,7 +139,7 @@ class LoRAModel:
         pin_memory = str(device) == "cpu" and is_pin_memory_available()
         loras: dict[str, LoRALayerWeights] = {}
         for tensor_name, tensor in tensors.items():
-            module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name(
+            module_name, is_lora_a = parse_fine_tuned_lora_name(
                 tensor_name, weights_mapper
             )
             if module_name not in loras:
@@ -160,13 +159,7 @@ class LoRAModel:
                     module_name, peft_helper, lora_embeddings_tensor
                 )
 
-            if is_bias:
-                loras[module_name].bias = tensor.to(device=device, dtype=dtype)
-                bias = tensor.to(device=device, dtype=dtype)
-                if pin_memory:
-                    bias = bias.pin_memory()
-                loras[module_name].bias = bias
-            elif is_lora_a:
+            if is_lora_a:
                 loras[module_name].lora_a = tensor.to(device=device, dtype=dtype)
                 if pin_memory:
                     loras[module_name].lora_a = loras[module_name].lora_a.pin_memory()
@@ -234,9 +227,7 @@ class LoRAModel:
 
         def check_unexpected_modules(modules: dict):
             for lora_module in modules.keys():  # noqa
-                module_name, _, _ = parse_fine_tuned_lora_name(
-                    lora_module, weights_mapper
-                )
+                module_name, _ = parse_fine_tuned_lora_name(lora_module, weights_mapper)
                 part_name = module_name.split(".")[-1]
                 if part_name not in expected_lora_modules:
                     unexpected_modules.append(module_name)
@@ -439,23 +430,11 @@ class LoRAModelManager:
             module_lora = self._get_lora_layer_weights(lora_model, module_name)
             if module_lora:
                 module_lora.optimize()
-                # Bias is not explicitly enabled with the flag enable_lora_bias.
-                bias = module_lora.bias
-                if (
-                    torch.is_tensor(bias)
-                    or (isinstance(bias, Sequence) and any(b is not None for b in bias))
-                ) and not self.lora_config.bias_enabled:
-                    module_lora.bias = None
-                    raise ValueError(
-                        f"Adapter bias cannot be used for {module_name}"
-                        " without --enable-lora-bias."
-                    )
                 module.set_lora(
                     index,
                     module_lora.lora_a,
                     module_lora.lora_b,
                     module_lora.embeddings_tensor,
-                    module_lora.bias,
                 )
             else:
                 module.reset_lora(index)
@@ -581,7 +560,6 @@ class LoRAModelManager:
         """Create zero-initialized LoRAModel for warmup."""
         model = LoRAModel(lora_id, rank, {})
         for module_name, module in self.model.named_modules():
-            bias_enabled = self.lora_config.bias_enabled
             if (
                 not self._match_target_modules(module_name)
                 or not isinstance(module, BaseLayerWithLoRA)
@@ -616,7 +594,6 @@ class LoRAModelManager:
                         module.lora_a_stacked[0].dtype,
                         "cpu",
                         embeddings_tensor_dim=embeddings_tensor_dim,
-                        bias_enabled=bias_enabled,
                     )
                 else:
                     lora = LoRALayerWeights.create_dummy_lora_weights(
@@ -626,7 +603,6 @@ class LoRAModelManager:
                         rank,
                         module.lora_a_stacked[0].dtype,
                         "cpu",
-                        bias_enabled=bias_enabled,
                     )
             else:
                 parts = module_name.split(".")
@@ -640,7 +616,6 @@ class LoRAModelManager:
                         rank,
                         module.lora_a_stacked[i].dtype,
                         "cpu",
-                        bias_enabled=bias_enabled,
                     )
                     subloras.append(lora)
                 lora = PackedLoRALayerWeights.pack(subloras)
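
Note: the runtime guard that nulled out adapter bias and raised during activation
is removed; adapters carrying bias now fail earlier, in PEFTHelper.validate_legal
(next hunks). The activation call is correspondingly down to four arguments:

    module.set_lora(
        index,
        module_lora.lora_a,
        module_lora.lora_b,
        module_lora.embeddings_tensor,
    )
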
@@ -29,7 +29,7 @@ class PEFTHelper:
     lora_alpha: int
     target_modules: Union[list[str], str]
 
-    bias: Literal["none", "all", "lora_only"] = field(default="none")
+    bias: Literal["none"] = field(default="none")
     modules_to_save: Optional[list[str]] = field(default=None)
     # True to use Rank-Stabilized LoRA (rsLoRA, see: https://arxiv.org/abs/2312.03732)
     use_rslora: bool = field(default=False)
@@ -122,7 +122,7 @@ class PEFTHelper:
                 f"LoRA rank {self.r} is greater than max_lora_rank"
                 f" {lora_config.max_lora_rank}."
             )
-        if self.bias != "none" and not lora_config.bias_enabled:
-            error_msg.append("Adapter bias cannot be used without bias_enabled.")
+        if self.bias != "none":
+            error_msg.append("Adapter bias is not supported.")
         if error_msg:
             raise ValueError(f"{' '.join(error_msg)}")
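
Note: a sketch of the new failure mode, assuming the usual PEFTHelper.from_dict
constructor and a pre-existing adapter config dict:

    helper = PEFTHelper.from_dict({**adapter_config, "bias": "all"})
    try:
        helper.validate_legal(lora_config)
    except ValueError as err:
        assert "Adapter bias is not supported." in str(err)
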
@@ -60,14 +60,13 @@ class PunicaWrapperABC(ABC):
         y: torch.Tensor,
         x: Union[tuple[torch.Tensor, ...], torch.Tensor],
         lora_b_stacked: tuple[torch.Tensor, ...],
-        lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
         output_slices: tuple[int, ...],
         offset_start: int = 0,
         add_inputs=True,
         **kwargs,
     ) -> Optional[torch.Tensor]:
         """
-        Performs GEMM and bias addition for multiple slices of lora_b.
+        Performs GEMM for multiple slices of lora_b.
         """
         raise NotImplementedError
 
@@ -93,7 +92,6 @@ class PunicaWrapperABC(ABC):
         x: torch.Tensor,
         lora_a_stacked: tuple[torch.Tensor, ...],
         lora_b_stacked: tuple[torch.Tensor, ...],
-        lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
         scale: float,
         output_slices: tuple[int, ...],
         *,
@@ -222,38 +220,6 @@ class PunicaWrapperBase(PunicaWrapperABC):
         self.token_nums = token_nums
         self.no_lora = no_lora
-
-    def _apply_bias(
-        self,
-        indices: torch.Tensor,
-        output: torch.Tensor,
-        output_slices: tuple[int, ...],
-        lora_bias_stacked: tuple[Optional[torch.Tensor], ...],
-    ):
-        """Applies bias to output
-
-        Input shapes:
-            lora_bias_stacked:      3 element tuple of (num_loras, output_dim)
-            indices:           (batch_size)
-            output:            (batch_size, q_slice_size + 2*kv_slice_size)
-            output_slices:     n-1 element tuple of (slice_size...),
-                               where n is number of slices
-        """
-        org_output = output
-        output = output.view(-1, output.shape[-1])
-        indices = indices.view(-1)
-
-        offset_left = 0
-        for slice_idx, slice in enumerate(output_slices):
-            bias = lora_bias_stacked[slice_idx]
-            if bias is not None:
-                bias = bias.view(-1, bias.shape[-1])
-                bias = bias[indices]
-                bias[indices == -1] = 0
-                output[:, offset_left : offset_left + slice] += bias
-            offset_left += slice
-
-        return output.view_as(org_output)
 
     @property
     def prefill_metadata(
         self,
@@ -365,29 +331,25 @@ class PunicaWrapperBase(PunicaWrapperABC):
         y: torch.Tensor,
         x: Union[tuple[torch.Tensor, ...], torch.Tensor],
         lora_b_stacked: tuple[torch.Tensor, ...],
-        lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
         output_slices: tuple[int, ...],
         offset_start: int = 0,
         add_inputs=True,
         **kwargs,
     ) -> Optional[torch.Tensor]:
         """
-        Performs GEMM and bias addition for multiple slices of lora_b.
+        Performs GEMM for multiple slices of lora_b.
 
         Semantics:
             offset = offset_start
             for i in range(len(lora_b_stacked)):
                 slice = output_slices[i]
-                y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
-                    lora_bias_stacked[i]
+                y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i]
                 offset += slice
 
         Args:
             y (torch.Tensor): Output tensor.
             x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
             lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
-            lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]):
-                bias's weight
             output_slices (tuple[int, ...]): Every slice's size
             offset_start (int): The starting position of y, defaults to 0
             add_inputs (bool): Defaults to True.
@@ -427,7 +389,6 @@ class PunicaWrapperBase(PunicaWrapperABC):
         x: torch.Tensor,
         lora_a_stacked: tuple[torch.Tensor, ...],
         lora_b_stacked: tuple[torch.Tensor, ...],
-        lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
         scale: float,
         output_slices: tuple[int, ...],
         *,
@@ -444,14 +405,13 @@ class PunicaWrapperBase(PunicaWrapperABC):
                     @ lora_a_stacked[indices[i], layer_idx, :, :]
                     @ lora_b_stacked[indices[i], layer_idx, :, :]
                     * scale
-                ).squeeze(0)+lora_bias_stacked[i]
+                ).squeeze(0)
 
         Args:
             y (torch.Tensor): Output tensor. Will be changed in-place.
             x (torch.Tensor): Input tensor
             lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight.
             lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight.
-            lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias.
             scale (float): Scaling factor.
             output_slices (tuple[int, ...]): Every slice's size.
             buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None.
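
Note: lora_bias_stacked disappears from both abstract signatures, so the CPU, GPU,
and TPU implementations below and their callers shrink in lockstep. The new
add_expand call convention, as used by add_lora_linear:

    self.add_expand(
        y, buffer, lora_b_stacked, output_slices, add_inputs=True, **kwargs
    )
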
@@ -199,38 +199,30 @@ class PunicaWrapperCPU(PunicaWrapperBase):
         y: torch.Tensor,
         x: Union[tuple[torch.Tensor, ...], torch.Tensor],
         lora_b_stacked: tuple[torch.Tensor, ...],
-        lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
         output_slices: tuple[int, ...],
         offset_start: int = 0,
         add_inputs=True,
         **kwargs,
     ) -> None:
         """
-        Performs GEMM and bias addition for multiple slices of lora_b.
+        Performs GEMM for multiple slices of lora_b.
 
         Semantics:
             for i in range(len(lora_b_stacked)):
                 slice = output_slices[i]
-                y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
-                    lora_bias_stacked[i]
+                y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i]
                 offset += slice
 
         Args:
             y (torch.Tensor): Output tensor.
             x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
             lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
-            lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]):
-                bias's weight
             output_slices (tuple[int, ...]): Every slice's size
             add_inputs (bool): Defaults to True.
         """
         y_org = y
         y = y.view(-1, y.shape[-1])
         offset_left = offset_start
-        if lora_bias_stacked is not None:
-            self._apply_bias(
-                self.token_lora_indices, y, output_slices, lora_bias_stacked
-            )
         for slice_idx in range(len(lora_b_stacked)):
             self._apply_expand(
                 y,
@@ -276,7 +268,6 @@ class PunicaWrapperCPU(PunicaWrapperBase):
         x: torch.Tensor,
         lora_a_stacked: tuple[torch.Tensor, ...],
         lora_b_stacked: tuple[torch.Tensor, ...],
-        lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
         scale: float,
         output_slices: tuple[int, ...],
         *,
@@ -293,25 +284,19 @@ class PunicaWrapperCPU(PunicaWrapperBase):
                     @ lora_a_stacked[indices[i], layer_idx, :, :]
                     @ lora_b_stacked[indices[i], layer_idx, :, :]
                     * scale
-                ).squeeze(0)+lora_bias_stacked[i]
+                ).squeeze(0)
 
         Args:
             y (torch.Tensor): Output tensor. Will be changed in-place.
             x (torch.Tensor): Input tensor
             lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight.
             lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight.
-            lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias.
             scale (float): Scaling factor.
             output_slices (tuple[int, ...]): Every slice's size.
             buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None.
         """
 
         assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
-        if lora_bias_stacked is not None:
-            assert len(lora_bias_stacked) == len(output_slices)
-            y = self._apply_bias(
-                self.token_lora_indices, y, output_slices, lora_bias_stacked
-            )
 
         if buffer is None:
             r = lora_b_stacked[0].size(-1)
@@ -323,7 +308,7 @@ class PunicaWrapperCPU(PunicaWrapperBase):
         )
         self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
         self.add_expand(
-            y, buffer, lora_b_stacked, None, output_slices, add_inputs=True, **kwargs
+            y, buffer, lora_b_stacked, output_slices, add_inputs=True, **kwargs
         )
 
     def add_lora_logits(
@@ -101,36 +101,29 @@ class PunicaWrapperGPU(PunicaWrapperBase):
         y: torch.Tensor,
         x: torch.Tensor,
         lora_b_stacked: tuple[torch.Tensor, ...],
-        lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
         output_slices: tuple[int, ...],
         offset_start: int = 0,
         add_inputs=True,
         **kwargs,
     ) -> None:
         """
-        Performs GEMM and bias addition for multiple slices of lora_b.
+        Performs GEMM for multiple slices of lora_b.
 
         Semantics:
             for i in range(len(lora_b_stacked)):
                 slice = output_slices[i]
-                y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
-                    lora_bias_stacked[i]
+                y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i]
                 offset += slice
 
         Args:
             y (torch.Tensor): Output tensor.
             x (torch.Tensor): Input tensors
             lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
-            lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]):
-                bias's weight
             output_slices (tuple[int, ...]): Every slice's size
             add_inputs (bool): Defaults to True.
         """
         y_org = y
         y = y.view(-1, y.shape[-1])
-        if lora_bias_stacked is not None:
-            token_lora_indices = torch.narrow(self._token_lora_indices, 0, 0, y.size(0))
-            self._apply_bias(token_lora_indices, y, output_slices, lora_bias_stacked)
 
         assert x.ndim == 3
         assert x.size(0) == len(output_slices)
@@ -183,7 +176,6 @@ class PunicaWrapperGPU(PunicaWrapperBase):
         x: torch.Tensor,
         lora_a_stacked: tuple[torch.Tensor, ...],
         lora_b_stacked: tuple[torch.Tensor, ...],
-        lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
         scale: float,
         output_slices: tuple[int, ...],
         *,
@@ -200,26 +192,18 @@ class PunicaWrapperGPU(PunicaWrapperBase):
                     @ lora_a_stacked[indices[i], layer_idx, :, :]
                     @ lora_b_stacked[indices[i], layer_idx, :, :]
                     * scale
-                ).squeeze(0)+lora_bias_stacked[i]
+                ).squeeze(0)
 
         Args:
             y (torch.Tensor): Output tensor. Will be changed in-place.
             x (torch.Tensor): Input tensor
             lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight.
             lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight.
-            lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias.
             scale (float): Scaling factor.
             output_slices (tuple[int, ...]): Every slice's size.
             buffer (Optional[torch.Tensor]): Defaults to None.
         """
 
         assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
-        if lora_bias_stacked is not None:
-            assert len(lora_bias_stacked) == len(output_slices)
-            token_lora_indices = torch.narrow(self._token_lora_indices, 0, 0, y.size(0))
-            y = self._apply_bias(
-                token_lora_indices, y, output_slices, lora_bias_stacked
-            )
 
         if buffer is None:
             r = lora_b_stacked[0].size(-1)
@@ -241,7 +225,6 @@ class PunicaWrapperGPU(PunicaWrapperBase):
             y,
             buffer,  # type: ignore
             lora_b_stacked,
-            None,
             output_slices,
             add_inputs=True,
             **kwargs,
@@ -139,28 +139,24 @@ class PunicaWrapperTPU(PunicaWrapperBase):
|
|||||||
y: torch.Tensor,
|
y: torch.Tensor,
|
||||||
x: Union[tuple[torch.Tensor, ...], torch.Tensor],
|
x: Union[tuple[torch.Tensor, ...], torch.Tensor],
|
||||||
lora_b_stacked: tuple[torch.Tensor, ...],
|
lora_b_stacked: tuple[torch.Tensor, ...],
|
||||||
lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
|
|
||||||
output_slices: tuple[int, ...],
|
output_slices: tuple[int, ...],
|
||||||
offset_start: int = 0,
|
offset_start: int = 0,
|
||||||
add_inputs=True,
|
add_inputs=True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Performs GEMM and bias addition for multiple slices of lora_b.
|
Performs GEMM for multiple slices of lora_b.
|
||||||
|
|
||||||
Semantics:
|
Semantics:
|
||||||
for i in range(len(lora_b_stacked)):
|
for i in range(len(lora_b_stacked)):
|
||||||
slice = output_slices[i]
|
slice = output_slices[i]
|
||||||
y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
|
y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i]
|
||||||
lora_bias_stacked[i]
|
|
||||||
offset += slice
|
offset += slice
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
y (torch.Tensor): Output tensor.
|
y (torch.Tensor): Output tensor.
|
||||||
x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
|
x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
|
||||||
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
|
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
|
||||||
lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]):
|
|
||||||
bias's weight
|
|
||||||
output_slices (tuple[int, ...]): Every slice's size
|
output_slices (tuple[int, ...]): Every slice's size
|
||||||
add_inputs (bool): Defaults to True.
|
add_inputs (bool): Defaults to True.
|
||||||
"""
|
"""
|
||||||
@@ -168,10 +164,6 @@ class PunicaWrapperTPU(PunicaWrapperBase):
|
|||||||
y = y.view(-1, y.shape[-1])
|
y = y.view(-1, y.shape[-1])
|
||||||
offset_left = 0
|
offset_left = 0
|
||||||
|
|
||||||
if lora_bias_stacked is not None:
|
|
||||||
y = self._apply_bias(
|
|
||||||
self._get_token_lora_indices(y), y, output_slices, lora_bias_stacked
|
|
||||||
)
|
|
||||||
for slice_idx in range(len(lora_b_stacked)):
|
for slice_idx in range(len(lora_b_stacked)):
|
||||||
y = self.expand_slice(
|
y = self.expand_slice(
|
||||||
y,
|
y,
|
||||||
@@ -214,7 +206,6 @@ class PunicaWrapperTPU(PunicaWrapperBase):
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
lora_a_stacked: tuple[torch.Tensor, ...],
|
lora_a_stacked: tuple[torch.Tensor, ...],
|
||||||
lora_b_stacked: tuple[torch.Tensor, ...],
|
lora_b_stacked: tuple[torch.Tensor, ...],
|
||||||
lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
|
|
||||||
scale: float,
|
scale: float,
|
||||||
output_slices: tuple[int, ...],
|
output_slices: tuple[int, ...],
|
||||||
*,
|
*,
|
||||||
@@ -231,25 +222,19 @@ class PunicaWrapperTPU(PunicaWrapperBase):
                 @ lora_a_stacked[indices[i], layer_idx, :, :]
                 @ lora_b_stacked[indices[i], layer_idx, :, :]
                 * scale
-            ).squeeze(0)+lora_bias_stacked[i]
+            ).squeeze(0)

         Args:
             y (torch.Tensor): Output tensor. Will not be changed in-place.
             x (torch.Tensor): Input tensor (T, E)
             lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight.
             lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight.
-            lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias.
             scale (float): Scaling factor.
             output_slices (tuple[int, ...]): Every slice's size.
             buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None.
         """

         assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
-        if lora_bias_stacked is not None:
-            assert len(lora_bias_stacked) == len(output_slices)
-            y = self._apply_bias(
-                self._get_token_lora_indices(y), y, output_slices, lora_bias_stacked
-            )

         if buffer is None:
             r = lora_b_stacked[0].size(-1)
@@ -261,7 +246,7 @@ class PunicaWrapperTPU(PunicaWrapperBase):
             )
         buffer = self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
         return self.add_expand(
-            y, buffer, lora_b_stacked, None, output_slices, add_inputs=True, **kwargs
+            y, buffer, lora_b_stacked, output_slices, add_inputs=True, **kwargs
         )

     def add_lora_logits(
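The updated call composes the two halves of the LoRA matmul with no bias step left in between. A hedged per-call sketch of that shrink-then-expand composition (single 2-D A/B pair and illustrative names, not the stacked tensors or Punica kernels):

import torch

def lora_linear_semantics(y, x, lora_a, lora_b, scale):
    # Shrink: project x into the low-rank space, scaled.
    buffer = (x @ lora_a) * scale   # (T, E) @ (E, r) -> (T, r)
    # Expand: project back up and accumulate into the output.
    y += buffer @ lora_b            # (T, r) @ (r, O) -> (T, O)
    return y

y = lora_linear_semantics(
    torch.zeros(3, 16), torch.randn(3, 64),
    torch.randn(64, 8), torch.randn(8, 16), scale=0.5,
)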
@@ -299,43 +284,6 @@ class PunicaWrapperTPU(PunicaWrapperBase):
         y = bgmv_expand(buffer, lora_b_stacked, y, sampler_indices, add_inputs=True)
         return y.view_as(y_org)

-    def _apply_bias(
-        self,
-        indices: torch.Tensor,
-        output: torch.Tensor,
-        output_slices: tuple[int, ...],
-        lora_bias_stacked: tuple[Optional[torch.Tensor], ...],
-    ):
-        """Applies bias to output
-
-        Input shapes:
-            lora_bias_stacked: 3 element tuple of (num_loras, output_dim)
-            indices: (batch_size)
-            output: (batch_size, q_slice_size + 2*kv_slice_size)
-            output_slices: n-1 element tuple of (slice_size...),
-                where n is number of slices
-        """
-        org_output = output
-        output = output.view(-1, output.shape[-1])
-        indices = indices.view(-1)
-
-        offset_left = 0
-        for slice_idx, slice in enumerate(output_slices):
-            bias = lora_bias_stacked[slice_idx]
-            if bias is not None:
-                bias = bias.view(-1, bias.shape[-1])
-                bias = bias[indices]
-                bias = torch.where(indices[:, None] == -1, 0, bias)
-
-                bias = F.pad(
-                    bias, (offset_left, output.shape[1] - (offset_left + slice), 0, 0)
-                )
-
-                output += bias
-            offset_left += slice
-
-        return output.view_as(org_output)
-
     # This performs the same tensor ops as the base method, except it does them
     # on the CPU then transfers the results to the TPU
     def _update_base_metadata(
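For context on the deleted helper: it zero-padded each per-slice bias out to the full output width so a single in-place add landed it at the right column offset, after masking rows whose LoRA index was -1. A tiny illustration of the padding step alone (hypothetical sizes, not vLLM code):

import torch
import torch.nn.functional as F

bias = torch.ones(4, 3)            # (batch_size, slice_size)
offset_left, total_width = 2, 8
# pad=(left, right, top, bottom): land the slice on columns 2..4 of 8.
padded = F.pad(bias, (offset_left, total_width - (offset_left + 3), 0, 0))
assert padded.shape == (4, 8)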

@@ -108,36 +108,29 @@ class PunicaWrapperXPU(PunicaWrapperBase):
         y: torch.Tensor,
         x: torch.Tensor,
         lora_b_stacked: tuple[torch.Tensor, ...],
-        lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
         output_slices: tuple[int, ...],
         offset_start: int = 0,
         add_inputs=True,
         **kwargs,
     ) -> None:
         """
-        Performs GEMM and bias addition for multiple slices of lora_b.
+        Performs GEMM for multiple slices of lora_b.

         Semantics:
             for i in range(len(lora_b_stacked)):
                 slice = output_slices[i]
-                y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
-                    lora_bias_stacked[i]
+                y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i]
                 offset += slice

         Args:
             y (torch.Tensor): Output tensor.
             x (torch.Tensor): Input tensors
             lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
-            lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]):
-                bias's weight
             output_slices (tuple[int, ...]): Every slice's size
             add_inputs (bool): Defaults to True.
         """
         y_org = y
         y = y.view(-1, y.shape[-1])
-        if lora_bias_stacked is not None:
-            token_lora_indices = self._get_token_lora_indices(y)
-            self._apply_bias(token_lora_indices, y, output_slices, lora_bias_stacked)

         assert x.ndim == 3
         assert x.size(0) == len(output_slices)
@@ -184,7 +177,6 @@ class PunicaWrapperXPU(PunicaWrapperBase):
         x: torch.Tensor,
         lora_a_stacked: tuple[torch.Tensor, ...],
         lora_b_stacked: tuple[torch.Tensor, ...],
-        lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
         scale: float,
         output_slices: tuple[int, ...],
         *,
@@ -201,26 +193,19 @@ class PunicaWrapperXPU(PunicaWrapperBase):
                 @ lora_a_stacked[indices[i], layer_idx, :, :]
                 @ lora_b_stacked[indices[i], layer_idx, :, :]
                 * scale
-            ).squeeze(0)+lora_bias_stacked[i]
+            ).squeeze(0)

         Args:
             y (torch.Tensor): Output tensor. Will be changed in-place.
             x (torch.Tensor): Input tensor
             lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight.
             lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight.
-            lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias.
             scale (float): Scaling factor.
             output_slices (tuple[int, ...]): Every slice's size.
             buffer (Optional[torch.Tensor]): Defaults to None.
         """

         assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
-        if lora_bias_stacked is not None:
-            assert len(lora_bias_stacked) == len(output_slices)
-            token_lora_indices = self._get_token_lora_indices(y)
-            y = self._apply_bias(
-                token_lora_indices, y, output_slices, lora_bias_stacked
-            )

         if buffer is None:
             r = lora_b_stacked[0].size(-1)
@@ -242,7 +227,6 @@ class PunicaWrapperXPU(PunicaWrapperBase):
             y,
             buffer,  # type: ignore
             lora_b_stacked,
-            None,
             output_slices,
             add_inputs=True,
             **kwargs,

@@ -112,7 +112,7 @@ def replace_submodule(

 def parse_fine_tuned_lora_name(
     name: str, weights_mapper: Optional["WeightsMapper"] = None
-) -> tuple[str, bool, bool]:
+) -> tuple[str, bool]:
     """Parse the name of lora weights.

     args:
@@ -124,7 +124,6 @@ def parse_fine_tuned_lora_name(
         tuple(module_name, is_lora_a):
             module_name: the name of the module, e.g. model.dense1,
             is_lora_a whether the tensor is lora_a or lora_b.
-            is_bias whether the tensor is lora bias.
     """

     # LoRA weight qualified name usually starts with `base_model.model.`,
@@ -146,15 +145,11 @@ def parse_fine_tuned_lora_name(
     parts = name.split(".")
     if parts[-1] == "weight" and (parts[-2] == "lora_A" or parts[-2] == "lora_B"):
         new_name = ".".join(parts[start_index:-2])
-        return new_name, parts[-2] == "lora_A", False
+        return new_name, parts[-2] == "lora_A"

     if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B":
         new_name = ".".join(parts[start_index:-1])
-        return new_name, parts[-1] == "lora_embedding_A", False
+        return new_name, parts[-1] == "lora_embedding_A"

-    if parts[-1] == "bias":
-        new_name = ".".join(parts[start_index:-2])
-        return new_name, False, True
-
     raise ValueError(f"{name} is unsupported LoRA weight")

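With the bias branch gone, the parser returns a 2-tuple and rejects bias tensors outright. A hedged usage sketch (hypothetical checkpoint tensor name; assumes the `base_model.model.` prefix stripping shown above):

from vllm.lora.utils import parse_fine_tuned_lora_name

module_name, is_lora_a = parse_fine_tuned_lora_name(
    "base_model.model.model.layers.0.mlp.down_proj.lora_A.weight"
)
assert module_name == "model.layers.0.mlp.down_proj"
assert is_lora_a
# A trailing ".bias" name now falls through to the ValueError above.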