Remove LoRA bias support (#25807)

Signed-off-by: Ashwin Phadke <ashwinphadke12@rediffmail.com>
Signed-off-by: Ashwin Phadke <23502062+ashwin-phadke@users.noreply.github.com>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Authored by Ashwin Phadke on 2025-10-10 15:20:33 +05:30; committed by GitHub.
parent 3ee202ea1e
commit ab196edefb
20 changed files with 35 additions and 366 deletions

View File

@@ -23,11 +23,6 @@ BADREQUEST_CASES = [
         {"r": 1024},
         "is greater than max_lora_rank",
     ),
-    (
-        "test_bias",
-        {"bias": "all"},
-        "Adapter bias cannot be used without bias_enabled",
-    ),
     ("test_dora", {"use_dora": True}, "does not yet support DoRA"),
     (
         "test_modules_to_save",

View File

@@ -16,11 +16,6 @@ ERROR_CASES = [
         {"r": 1024},
         "is greater than max_lora_rank",
     ),
-    (
-        "test_bias",
-        {"bias": "all"},
-        "Adapter bias cannot be used without bias_enabled",
-    ),
     ("test_dora", {"use_dora": True}, "does not yet support DoRA"),
     (
         "test_modules_to_save",

View File

@@ -21,7 +21,6 @@ class LoRANameParserTestConfig(NamedTuple):
     name: str
     module_name: str
     is_lora_a: bool
-    is_bias: bool
     weights_mapper: Optional[WeightsMapper] = None
@@ -37,44 +36,37 @@ def test_parse_fine_tuned_lora_name_valid():
             "base_model.model.model.embed_tokens.lora_embedding_A",
             "model.embed_tokens",
             True,
-            False,
         ),
         LoRANameParserTestConfig(
             "base_model.model.model.embed_tokens.lora_embedding_B",
             "model.embed_tokens",
             False,
-            False,
         ),
         LoRANameParserTestConfig(
             "base_model.model.model.layers.9.mlp.down_proj.lora_A.weight",
             "model.layers.9.mlp.down_proj",
             True,
-            False,
         ),
         LoRANameParserTestConfig(
             "base_model.model.model.layers.9.mlp.down_proj.lora_B.weight",
             "model.layers.9.mlp.down_proj",
             False,
-            False,
         ),
         LoRANameParserTestConfig(
             "language_model.layers.9.mlp.down_proj.lora_A.weight",
             "language_model.layers.9.mlp.down_proj",
             True,
-            False,
         ),
         LoRANameParserTestConfig(
             "language_model.layers.9.mlp.down_proj.lora_B.weight",
             "language_model.layers.9.mlp.down_proj",
             False,
-            False,
         ),
         # Test with WeightsMapper
         LoRANameParserTestConfig(
             "base_model.model.model.layers.9.mlp.down_proj.lora_A.weight",
             "language_model.model.layers.9.mlp.down_proj",
             True,
-            False,
             weights_mapper=WeightsMapper(
                 orig_to_new_prefix={"model.": "language_model.model."}
             ),
@@ -83,7 +75,6 @@ def test_parse_fine_tuned_lora_name_valid():
             "base_model.model.model.layers.9.mlp.down_proj.lora_B.weight",
             "language_model.model.layers.9.mlp.down_proj",
             False,
-            False,
             weights_mapper=WeightsMapper(
                 orig_to_new_prefix={"model.": "language_model.model."}
             ),
@@ -92,7 +83,6 @@ def test_parse_fine_tuned_lora_name_valid():
             "model.layers.9.mlp.down_proj.lora_A.weight",
             "language_model.model.layers.9.mlp.down_proj",
             True,
-            False,
             weights_mapper=WeightsMapper(
                 orig_to_new_prefix={"model.": "language_model.model."}
             ),
@@ -101,14 +91,13 @@ def test_parse_fine_tuned_lora_name_valid():
             "model.layers.9.mlp.down_proj.lora_B.weight",
             "language_model.model.layers.9.mlp.down_proj",
             False,
-            False,
             weights_mapper=WeightsMapper(
                 orig_to_new_prefix={"model.": "language_model.model."}
             ),
         ),
     ]
-    for name, module_name, is_lora_a, is_bias, weights_mapper in fixture:
-        assert (module_name, is_lora_a, is_bias) == parse_fine_tuned_lora_name(
+    for name, module_name, is_lora_a, weights_mapper in fixture:
+        assert (module_name, is_lora_a) == parse_fine_tuned_lora_name(
             name, weights_mapper
         )

View File

@@ -70,12 +70,6 @@ class LoRAConfig:
     per prompt. When run in offline mode, the lora IDs for n modalities
     will be automatically assigned to 1-n with the names of the modalities
     in alphabetic order."""
-    bias_enabled: bool = Field(
-        default=False,
-        deprecated="`bias_enabled` is deprecated and will be removed in v0.12.0.",
-    )
-    """[DEPRECATED] Enable bias for LoRA adapters. This option will be
-    removed in v0.12.0."""
     def compute_hash(self) -> str:
         """
@@ -96,7 +90,7 @@ class LoRAConfig:
         factors.append(self.lora_dtype)
         factors.append(self.lora_extra_vocab_size)
         factors.append(self.lora_vocab_padding_size)
-        factors.append(self.bias_enabled)
         hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
         return hash_str
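
With bias_enabled removed, compute_hash builds the config fingerprint from the remaining LoRA factors only, so two configs that previously differed only in that flag now hash identically. A standalone sketch of the same hashing pattern (a plain function, not vLLM's LoRAConfig; any factor not visible in this hunk is an assumption):

import hashlib

def lora_config_hash(
    lora_dtype: str,
    lora_extra_vocab_size: int,
    lora_vocab_padding_size: int,
    max_lora_rank: int = 16,  # assumed extra factor, not shown in the hunk
) -> str:
    # Same idea as compute_hash above: stringify the factor list, take an MD5 digest.
    factors = [lora_dtype, lora_extra_vocab_size, lora_vocab_padding_size, max_lora_rank]
    return hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()

# Configs that only differed in the removed bias flag now produce the same hash.
print(lora_config_hash("bfloat16", 256, 256))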

View File

@@ -439,7 +439,6 @@ class EngineArgs:
     video_pruning_rate: float = MultiModalConfig.video_pruning_rate
     # LoRA fields
     enable_lora: bool = False
-    enable_lora_bias: bool = LoRAConfig.bias_enabled
     max_loras: int = LoRAConfig.max_loras
     max_lora_rank: int = LoRAConfig.max_lora_rank
     default_mm_loras: Optional[dict[str, str]] = LoRAConfig.default_mm_loras
@@ -916,7 +915,6 @@ class EngineArgs:
            action=argparse.BooleanOptionalAction,
            help="If True, enable handling of LoRA adapters.",
        )
-        lora_group.add_argument("--enable-lora-bias", **lora_kwargs["bias_enabled"])
        lora_group.add_argument("--max-loras", **lora_kwargs["max_loras"])
        lora_group.add_argument("--max-lora-rank", **lora_kwargs["max_lora_rank"])
        lora_group.add_argument(
@@ -1515,7 +1513,6 @@ class EngineArgs:
        lora_config = (
            LoRAConfig(
-                bias_enabled=self.enable_lora_bias,
                max_lora_rank=self.max_lora_rank,
                max_loras=self.max_loras,
                default_mm_loras=self.default_mm_loras,
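
The parser no longer registers --enable-lora-bias and EngineArgs has no enable_lora_bias field, so the LoRA switches left on this path are the ones shown above. A hedged sketch of offline usage after this change (EngineArgs is a dataclass in vllm.engine.arg_utils; the model id is illustrative):

from vllm.engine.arg_utils import EngineArgs

# LoRA-related fields that remain after this commit, per the diff above.
args = EngineArgs(
    model="meta-llama/Llama-3.1-8B-Instruct",  # illustrative model id
    enable_lora=True,
    max_loras=4,
    max_lora_rank=16,
)
# EngineArgs(enable_lora_bias=True, ...) or `--enable-lora-bias` on the CLI is
# expected to be rejected now that the field and flag have been removed.
print(args.enable_lora, args.max_loras, args.max_lora_rank)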

View File

@@ -45,7 +45,6 @@ class BaseLayerWithLoRA(nn.Module):
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
-        bias: Optional[torch.Tensor] = None,
    ):
        """Overwrites lora tensors at index."""
        ...

View File

@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Optional, cast
+from typing import Optional
 import torch
 from transformers import PretrainedConfig
@@ -29,7 +29,6 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
         self.tp_size = self.base_layer.tp_size
         self.tp_rank = self.base_layer.tp_rank
         self.device = _get_lora_device(self.base_layer)
-        self.lora_bias_stacked: Optional[tuple[torch.Tensor, ...]] = None
         self.output_slices: tuple[int, ...]
         self.output_size: int
         self.n_slices: int
@@ -86,30 +85,12 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
             )
             for _ in range(self.n_slices)
         )
-        if lora_config.bias_enabled:
-            lora_bias_out_size = lora_b_out_size
-            self.lora_bias_stacked = tuple(
-                torch.zeros(
-                    max_loras,
-                    1,
-                    lora_bias_out_size,
-                    dtype=lora_config.lora_dtype,
-                    device=self.device,
-                )
-                for _ in range(self.n_slices)
-            )
         self.output_slices = (self.lora_b_stacked[0].shape[2],)
     def reset_lora(self, index: int):
         for s_index in range(self.n_slices):
             self.lora_a_stacked[s_index][index] = 0
             self.lora_b_stacked[s_index][index] = 0
-            if self.lora_config.bias_enabled:
-                # Make mypy happy
-                self.lora_bias_stacked = cast(
-                    tuple[torch.Tensor, ...], self.lora_bias_stacked
-                )
-                self.lora_bias_stacked[s_index][index] = 0
     def set_lora(
         self,
@@ -117,7 +98,6 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
         lora_a: torch.Tensor,
         lora_b: torch.Tensor,
         embeddings_tensor: Optional[torch.Tensor],
-        lora_bias: Optional[torch.Tensor] = None,
     ):
         # Except for QKVParallelLinearWithLoRA and
         # MergedColumnParallelLinearWithLoRA, all other linear LoRA layers
@@ -131,8 +111,6 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
         if self.tp_size > 1:
             lora_a = self.slice_lora_a(lora_a)
             lora_b = self.slice_lora_b(lora_b)
-            if lora_bias is not None:
-                lora_bias = self.slice_bias(lora_bias)
         self.lora_a_stacked[0][index, 0, : lora_a.shape[0], : lora_a.shape[1]].copy_(
             lora_a, non_blocking=True
@@ -140,14 +118,6 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
         self.lora_b_stacked[0][index, 0, : lora_b.shape[0], : lora_b.shape[1]].copy_(
             lora_b, non_blocking=True
         )
-        if lora_bias is not None:
-            self.lora_bias_stacked = cast(
-                tuple[torch.Tensor, ...], self.lora_bias_stacked
-            )
-            assert len(self.lora_bias_stacked)
-            self.lora_bias_stacked[0][index, 0, : lora_bias.shape[0]].copy_(
-                lora_bias, non_blocking=True
-            )
     def apply(
         self, x: torch.Tensor, bias: Optional[torch.Tensor] = None
@@ -162,13 +132,7 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
             x = x.flatten(0, 1)
         lora_output: Optional[torch.Tensor] = self.punica_wrapper.add_lora_linear(
-            output,
-            x,
-            self.lora_a_stacked,
-            self.lora_b_stacked,
-            self.lora_bias_stacked,
-            1.0,
-            self.output_slices,
+            output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0, self.output_slices
         )
         if not current_platform.can_update_inplace():
             output = lora_output
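
For orientation on the set_lora path above: each adapter slot copies its (possibly TP-sliced) lora_a/lora_b into preallocated stacked buffers, and there is no longer a parallel lora_bias_stacked buffer to fill. A toy sketch of that copy pattern (buffer shapes and sizes here are assumptions, loosely following the torch.zeros calls and copy_ slicing in this file):

import torch

max_loras, max_rank, in_dim, out_dim = 4, 16, 32, 64
lora_a_stacked = torch.zeros(max_loras, 1, max_rank, in_dim)
lora_b_stacked = torch.zeros(max_loras, 1, out_dim, max_rank)

def set_lora(index: int, lora_a: torch.Tensor, lora_b: torch.Tensor) -> None:
    # Zero the slot, then copy the adapter into the top-left corner of each buffer;
    # adapters with rank < max_rank leave the rest of the slot as zeros.
    lora_a_stacked[index] = 0
    lora_b_stacked[index] = 0
    lora_a_stacked[index, 0, : lora_a.shape[0], : lora_a.shape[1]].copy_(lora_a)
    lora_b_stacked[index, 0, : lora_b.shape[0], : lora_b.shape[1]].copy_(lora_b)

set_lora(2, torch.randn(8, in_dim), torch.randn(out_dim, 8))  # a rank-8 adapter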

View File

@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Optional, Union, cast
+from typing import Optional, Union
 import torch
 import torch.nn as nn
@@ -32,8 +32,6 @@ def _mcp_apply(x, bias, layer: "ColumnParallelLinearWithLoRA"):
        == len(layer.lora_b_stacked)
        == len(layer.output_slices)
    )
-    if layer.lora_bias_stacked is not None:
-        assert layer.n_slices == len(layer.lora_bias_stacked)
    output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias)
@@ -61,7 +59,6 @@ def _mcp_apply(x, bias, layer: "ColumnParallelLinearWithLoRA"):
        output,
        buffers,
        layer.lora_b_stacked,
-        layer.lora_bias_stacked,
        layer.output_slices,
        offset_start=0,
        add_input=True,
@@ -122,16 +119,6 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
        lora_b = lora_b[start_idx:end_idx, :]
        return lora_b
-    def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
-        # TODO: Fix the slicing logic of bias.
-        if bias is None:
-            return bias
-        shard_size = self.output_size
-        start_idx = self.tp_rank * shard_size
-        end_idx = (self.tp_rank + 1) * shard_size
-        bias = bias[start_idx:end_idx]
-        return bias
    def forward(
        self, input_: torch.Tensor
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]:
@@ -238,17 +225,6 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
            )
            for output_size in self.output_slices
        )
-        if lora_config.bias_enabled:
-            self.lora_bias_stacked = tuple(
-                torch.zeros(
-                    max_loras,
-                    1,
-                    output_size,
-                    dtype=lora_config.lora_dtype,
-                    device=self.device,
-                )
-                for output_size in self.output_slices
-            )
    def slice_lora_a(
        self, lora_a: list[Union[torch.Tensor, None]]
@@ -268,31 +244,18 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
        ]
        return sliced_lora_b
-    def slice_bias(
-        self, bias: list[Union[torch.Tensor, None]]
-    ) -> list[Union[torch.Tensor, None]]:
-        for i, (shard_id, shard_size) in enumerate(
-            zip(self.output_ids, self.output_slices)
-        ):
-            if (bias_i := bias[i]) is not None:
-                bias[i] = bias_i[shard_size * shard_id : shard_size * (shard_id + 1)]
-        return bias
    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
-        lora_bias: Optional[torch.Tensor] = None,
    ):
        self.reset_lora(index)
        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)
-            if lora_bias is not None:
-                lora_bias = self.slice_bias(lora_bias)
        for i in range(self.n_slices):
            if (lora_a_i := lora_a[i]) is not None:
@@ -304,16 +267,6 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
                    index, 0, : lora_b_i.shape[0], : lora_b_i.shape[1]
                ].copy_(lora_b_i, non_blocking=True)
-        if lora_bias is not None:
-            self.lora_bias_stacked = cast(
-                tuple[torch.Tensor, ...], self.lora_bias_stacked
-            )
-            for i in range(self.n_slices):
-                if (lora_bias_i := lora_bias[i]) is not None:
-                    self.lora_bias_stacked[i][index, 0, : lora_bias_i.shape[0]].copy_(
-                        lora_bias_i, non_blocking=True
-                    )
    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
@@ -380,24 +333,6 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
        lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=0)
        return lora_b
-    def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
-        bias_q = bias[
-            self.q_proj_shard_size * self.q_shard_id : self.q_proj_shard_size
-            * (self.q_shard_id + 1)
-        ]
-        k_offset = self.q_proj_total_size
-        bias_k = bias[
-            k_offset + self.kv_proj_shard_size * self.kv_shard_id : k_offset
-            + self.kv_proj_shard_size * (self.kv_shard_id + 1)
-        ]
-        v_offset = k_offset + self.kv_proj_total_size
-        bias_v = bias[
-            v_offset + self.kv_proj_shard_size * self.kv_shard_id : v_offset
-            + self.kv_proj_shard_size * (self.kv_shard_id + 1)
-        ]
-        bias = torch.cat([bias_q, bias_k, bias_v], dim=1)
-        return bias
    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(

View File

@@ -143,7 +143,6 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
-        bias: Optional[torch.Tensor] = None,
    ):
        self.reset_lora(index)
        self.lora_a_stacked[index, 0, : lora_a.shape[0], : lora_a.shape[1]].copy_(

View File

@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Optional, Union, cast
+from typing import Optional, Union
 import torch
 import torch.nn as nn
@@ -39,9 +39,6 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        return lora_b
-    def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
-        return bias
    def forward(
        self, input_: torch.Tensor
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]:
@@ -123,16 +120,6 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
        lora_b = lora_b[start_idx:end_idx, :]
        return lora_b
-    def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
-        if bias is None:
-            return bias
-        self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], self.lora_bias_stacked)
-        shard_size = self.lora_bias_stacked[0].shape[2]
-        start_idx = self.tp_rank * shard_size
-        end_idx = (self.tp_rank + 1) * shard_size
-        bias = bias[start_idx:end_idx]
-        return bias
    def apply(
        self, x: torch.Tensor, bias: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
@@ -167,7 +154,6 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
            output,
            buffer,
            self.lora_b_stacked,
-            self.lora_bias_stacked,
            self.output_slices,
            offset_start=offset_start,
            add_input=True,

View File

@@ -91,7 +91,6 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
-        bias: Optional[torch.Tensor] = None,
    ):
        self.reset_lora(index)
        # NOTE self.lora_a_stacked is row-major, and lora_a is col-major,

View File

@@ -21,7 +21,6 @@ class LoRALayerWeights:
        lora_alpha: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
-        bias: Optional[torch.Tensor] = None,
        embeddings_tensor: Optional[torch.Tensor] = None,
        scaling: Optional[float] = None,
    ) -> None:
@@ -30,7 +29,6 @@ class LoRALayerWeights:
        self.lora_alpha = lora_alpha
        self.lora_a = lora_a
        self.lora_b = lora_b
-        self.bias = bias
        self.embeddings_tensor = embeddings_tensor
        if scaling is None:
@@ -71,13 +69,13 @@ class LoRALayerWeights:
        peft_helper: PEFTHelper,
        embeddings_tensor: Optional[torch.Tensor] = None,
    ) -> "LoRALayerWeights":
+        # lora_a and lora_b are set to None for config-based construction
        return cls(
            module_name,
            peft_helper.r,
            peft_helper.lora_alpha,
            None,
            None,
-            None,
            embeddings_tensor,
            peft_helper.vllm_lora_scaling_factor,
        )
@@ -92,7 +90,6 @@ class LoRALayerWeights:
        dtype: torch.dtype,
        device: torch.types.Device,
        embeddings_tensor_dim: Optional[int] = None,
-        bias_enabled: Optional[bool] = False,
    ) -> "LoRALayerWeights":
        pin_memory = str(device) == "cpu" and is_pin_memory_available()
        lora_a = torch.zeros(
@@ -101,12 +98,6 @@ class LoRALayerWeights:
        lora_b = torch.zeros(
            [output_dim, rank], dtype=dtype, device=device, pin_memory=pin_memory
        )
-        if bias_enabled:
-            bias = torch.zeros(
-                [output_dim], dtype=dtype, device=device, pin_memory=pin_memory
-            )
-        else:
-            bias = None
        embeddings_tensor = (
            torch.rand(
@@ -125,7 +116,6 @@ class LoRALayerWeights:
            lora_alpha=1,
            lora_a=lora_a,
            lora_b=lora_b,
-            bias=bias,
            embeddings_tensor=embeddings_tensor,
        )
@@ -140,7 +130,6 @@ class PackedLoRALayerWeights(LoRALayerWeights):
        lora_alphas: list[Optional[int]],
        lora_a: list[Optional[torch.Tensor]],
        lora_b: list[Optional[torch.Tensor]],
-        bias: Optional[list[Optional[torch.Tensor]]] = None,
        scaling: Optional[list[float]] = None,
    ) -> None:
        super().__init__(
@@ -149,7 +138,6 @@ class PackedLoRALayerWeights(LoRALayerWeights):
            lora_alpha=0,
            lora_a=lora_a,
            lora_b=lora_b,
-            bias=bias,
            scaling=scaling,  # type: ignore
            embeddings_tensor=None,
        )
@@ -181,7 +169,6 @@ class PackedLoRALayerWeights(LoRALayerWeights):
            [lora.lora_alpha if lora is not None else None for lora in loras],
            [lora.lora_a if lora is not None else None for lora in loras],
            [lora.lora_b if lora is not None else None for lora in loras],
-            [lora.bias if lora is not None else None for lora in loras],
            scaling=[
                1 if lora is not None else None  # type: ignore
                for lora in loras
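
LoRALayerWeights now carries only lora_a, lora_b, and the optional embeddings tensor; the delta a module contributes is the pure low-rank product with no additive bias term. A minimal numeric sketch of that math (tensor layouts and the alpha/r scaling here are common LoRA conventions, not taken verbatim from this file):

import torch

rank, in_dim, out_dim = 8, 64, 128
lora_a = torch.randn(in_dim, rank)   # maps input -> low-rank space
lora_b = torch.randn(out_dim, rank)  # maps low-rank space -> output
scaling = 16 / rank                  # lora_alpha / r

def lora_delta(x: torch.Tensor) -> torch.Tensor:
    # y += ((x @ A) @ B^T) * scale; after this commit there is no "+ bias" step.
    return (x @ lora_a) @ lora_b.T * scaling

print(lora_delta(torch.randn(2, in_dim)).shape)  # torch.Size([2, 128])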

View File

@@ -3,7 +3,6 @@
 import math
 import os
-from collections.abc import Sequence
 from typing import Callable, Optional, TypeVar, Union
 import regex as re
@@ -140,7 +139,7 @@ class LoRAModel:
        pin_memory = str(device) == "cpu" and is_pin_memory_available()
        loras: dict[str, LoRALayerWeights] = {}
        for tensor_name, tensor in tensors.items():
-            module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name(
+            module_name, is_lora_a = parse_fine_tuned_lora_name(
                tensor_name, weights_mapper
            )
            if module_name not in loras:
@@ -160,13 +159,7 @@ class LoRAModel:
                    module_name, peft_helper, lora_embeddings_tensor
                )
-            if is_bias:
-                loras[module_name].bias = tensor.to(device=device, dtype=dtype)
-                bias = tensor.to(device=device, dtype=dtype)
-                if pin_memory:
-                    bias = bias.pin_memory()
-                loras[module_name].bias = bias
-            elif is_lora_a:
+            if is_lora_a:
                loras[module_name].lora_a = tensor.to(device=device, dtype=dtype)
                if pin_memory:
                    loras[module_name].lora_a = loras[module_name].lora_a.pin_memory()
@@ -234,9 +227,7 @@ class LoRAModel:
        def check_unexpected_modules(modules: dict):
            for lora_module in modules.keys():  # noqa
-                module_name, _, _ = parse_fine_tuned_lora_name(
-                    lora_module, weights_mapper
-                )
+                module_name, _ = parse_fine_tuned_lora_name(lora_module, weights_mapper)
                part_name = module_name.split(".")[-1]
                if part_name not in expected_lora_modules:
                    unexpected_modules.append(module_name)
@@ -439,23 +430,11 @@ class LoRAModelManager:
            module_lora = self._get_lora_layer_weights(lora_model, module_name)
            if module_lora:
                module_lora.optimize()
-                # Bias is not explicitly enabled with the flag enable_lora_bias.
-                bias = module_lora.bias
-                if (
-                    torch.is_tensor(bias)
-                    or (isinstance(bias, Sequence) and any(b is not None for b in bias))
-                ) and not self.lora_config.bias_enabled:
-                    module_lora.bias = None
-                    raise ValueError(
-                        f"Adapter bias cannot be used for {module_name}"
-                        " without --enable-lora-bias."
-                    )
                module.set_lora(
                    index,
                    module_lora.lora_a,
                    module_lora.lora_b,
                    module_lora.embeddings_tensor,
-                    module_lora.bias,
                )
            else:
                module.reset_lora(index)
@@ -581,7 +560,6 @@ class LoRAModelManager:
        """Create zero-initialized LoRAModel for warmup."""
        model = LoRAModel(lora_id, rank, {})
        for module_name, module in self.model.named_modules():
-            bias_enabled = self.lora_config.bias_enabled
            if (
                not self._match_target_modules(module_name)
                or not isinstance(module, BaseLayerWithLoRA)
@@ -616,7 +594,6 @@ class LoRAModelManager:
                        module.lora_a_stacked[0].dtype,
                        "cpu",
                        embeddings_tensor_dim=embeddings_tensor_dim,
-                        bias_enabled=bias_enabled,
                    )
                else:
                    lora = LoRALayerWeights.create_dummy_lora_weights(
@@ -626,7 +603,6 @@ class LoRAModelManager:
                        rank,
                        module.lora_a_stacked[0].dtype,
                        "cpu",
-                        bias_enabled=bias_enabled,
                    )
            else:
                parts = module_name.split(".")
@@ -640,7 +616,6 @@ class LoRAModelManager:
                        rank,
                        module.lora_a_stacked[i].dtype,
                        "cpu",
-                        bias_enabled=bias_enabled,
                    )
                    subloras.append(lora)
                lora = PackedLoRALayerWeights.pack(subloras)

View File

@@ -29,7 +29,7 @@ class PEFTHelper:
    lora_alpha: int
    target_modules: Union[list[str], str]
-    bias: Literal["none", "all", "lora_only"] = field(default="none")
+    bias: Literal["none"] = field(default="none")
    modules_to_save: Optional[list[str]] = field(default=None)
    # True to use Rank-Stabilized LoRA (rsLoRA, see: https://arxiv.org/abs/2312.03732)
    use_rslora: bool = field(default=False)
@@ -122,7 +122,7 @@ class PEFTHelper:
                f"LoRA rank {self.r} is greater than max_lora_rank"
                f" {lora_config.max_lora_rank}."
            )
-        if self.bias != "none" and not lora_config.bias_enabled:
-            error_msg.append("Adapter bias cannot be used without bias_enabled.")
+        if self.bias != "none":
+            error_msg.append("Adapter bias is not supported.")
        if error_msg:
            raise ValueError(f"{' '.join(error_msg)}")

View File

@@ -60,14 +60,13 @@ class PunicaWrapperABC(ABC):
        y: torch.Tensor,
        x: Union[tuple[torch.Tensor, ...], torch.Tensor],
        lora_b_stacked: tuple[torch.Tensor, ...],
-        lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
        output_slices: tuple[int, ...],
        offset_start: int = 0,
        add_inputs=True,
        **kwargs,
    ) -> Optional[torch.Tensor]:
        """
-        Performs GEMM and bias addition for multiple slices of lora_b.
+        Performs GEMM for multiple slices of lora_b.
        """
        raise NotImplementedError
@@ -93,7 +92,6 @@ class PunicaWrapperABC(ABC):
        x: torch.Tensor,
        lora_a_stacked: tuple[torch.Tensor, ...],
        lora_b_stacked: tuple[torch.Tensor, ...],
-        lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
        scale: float,
        output_slices: tuple[int, ...],
        *,
@@ -222,38 +220,6 @@ class PunicaWrapperBase(PunicaWrapperABC):
        self.token_nums = token_nums
        self.no_lora = no_lora
-    def _apply_bias(
-        self,
-        indices: torch.Tensor,
-        output: torch.Tensor,
-        output_slices: tuple[int, ...],
-        lora_bias_stacked: tuple[Optional[torch.Tensor], ...],
-    ):
-        """Applies bias to output
-        Input shapes:
-            lora_bias_stacked:      3 element tuple of (num_loras, output_dim)
-            indices:                (batch_size)
-            output:                 (batch_size, q_slice_size + 2*kv_slice_size)
-            output_slices:          n-1 element tuple of (slice_size...),
-                                    where n is number of slices
-        """
-        org_output = output
-        output = output.view(-1, output.shape[-1])
-        indices = indices.view(-1)
-        offset_left = 0
-        for slice_idx, slice in enumerate(output_slices):
-            bias = lora_bias_stacked[slice_idx]
-            if bias is not None:
-                bias = bias.view(-1, bias.shape[-1])
-                bias = bias[indices]
-                bias[indices == -1] = 0
-                output[:, offset_left : offset_left + slice] += bias
-            offset_left += slice
-        return output.view_as(org_output)
    @property
    def prefill_metadata(
        self,
@@ -365,29 +331,25 @@ class PunicaWrapperBase(PunicaWrapperABC):
        y: torch.Tensor,
        x: Union[tuple[torch.Tensor, ...], torch.Tensor],
        lora_b_stacked: tuple[torch.Tensor, ...],
-        lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
        output_slices: tuple[int, ...],
        offset_start: int = 0,
        add_inputs=True,
        **kwargs,
    ) -> Optional[torch.Tensor]:
        """
-        Performs GEMM and bias addition for multiple slices of lora_b.
+        Performs GEMM for multiple slices of lora_b.
        Semantics:
            offset = offset_start
            for i in range(len(lora_b_stacked)):
                slice = output_slices[i]
-                y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
-                    lora_bias_stacked[i]
+                y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i]
                offset += slice
        Args:
            y (torch.Tensor): Output tensor.
            x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
            lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
-            lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]):
-                bias's weight
            output_slices (tuple[int, ...]): Every slice's size
            offset_start (int): The starting position of y, defaults to 0
            add_inputs (bool): Defaults to True.
@@ -427,7 +389,6 @@ class PunicaWrapperBase(PunicaWrapperABC):
        x: torch.Tensor,
        lora_a_stacked: tuple[torch.Tensor, ...],
        lora_b_stacked: tuple[torch.Tensor, ...],
-        lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
        scale: float,
        output_slices: tuple[int, ...],
        *,
@@ -444,14 +405,13 @@ class PunicaWrapperBase(PunicaWrapperABC):
                    @ lora_a_stacked[indices[i], layer_idx, :, :]
                    @ lora_b_stacked[indices[i], layer_idx, :, :]
                    * scale
-                    ).squeeze(0)+lora_bias_stacked[i]
+                    ).squeeze(0)
        Args:
            y (torch.Tensor): Output tensor. Will be changed in-place.
            x (torch.Tensor): Input tensor
            lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight.
            lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight.
-            lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias.
            scale (float): Scaling factor.
            output_slices (tuple[int, ...]): Every slice's size.
            buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None.
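
The docstrings above now describe a pure shrink/expand pipeline: the input is projected through lora_a, the per-slice results are scattered through lora_b, and no bias tensor is added anywhere. A dense PyTorch restatement of the add_lora_linear semantics (a reference loop over tokens and slices, not the Punica kernels; the stacked-tensor shapes are assumptions):

import torch

def add_lora_linear_reference(y, x, lora_a_stacked, lora_b_stacked, indices,
                              scale, output_slices):
    # y[t, offset:offset+slice] += x[t] @ lora_a[idx] @ lora_b[idx] * scale for each
    # token t whose adapter index idx >= 0; tokens mapped to -1 receive no delta.
    offset = 0
    for a, b, size in zip(lora_a_stacked, lora_b_stacked, output_slices):
        for t, idx in enumerate(indices.tolist()):
            if idx < 0:
                continue
            y[t, offset:offset + size] += (x[t] @ a[idx, 0].T) @ b[idx, 0].T * scale
        offset += size
    return y

# Smoke test: one slice, two tokens, adapter 0 active for the first token only.
in_dim, out_dim, rank, max_loras = 8, 6, 4, 2
x, y = torch.randn(2, in_dim), torch.zeros(2, out_dim)
a = (torch.randn(max_loras, 1, rank, in_dim),)   # assumed [max_loras, 1, rank, in]
b = (torch.randn(max_loras, 1, out_dim, rank),)  # assumed [max_loras, 1, out, rank]
add_lora_linear_reference(y, x, a, b, torch.tensor([0, -1]), 1.0, (out_dim,))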

View File

@@ -199,38 +199,30 @@ class PunicaWrapperCPU(PunicaWrapperBase):
        y: torch.Tensor,
        x: Union[tuple[torch.Tensor, ...], torch.Tensor],
        lora_b_stacked: tuple[torch.Tensor, ...],
-        lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
        output_slices: tuple[int, ...],
        offset_start: int = 0,
        add_inputs=True,
        **kwargs,
    ) -> None:
        """
-        Performs GEMM and bias addition for multiple slices of lora_b.
+        Performs GEMM for multiple slices of lora_b.
        Semantics:
            for i in range(len(lora_b_stacked)):
                slice = output_slices[i]
-                y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
-                    lora_bias_stacked[i]
+                y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i]
                offset += slice
        Args:
            y (torch.Tensor): Output tensor.
            x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
            lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
-            lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]):
-                bias's weight
            output_slices (tuple[int, ...]): Every slice's size
            add_inputs (bool): Defaults to True.
        """
        y_org = y
        y = y.view(-1, y.shape[-1])
        offset_left = offset_start
-        if lora_bias_stacked is not None:
-            self._apply_bias(
-                self.token_lora_indices, y, output_slices, lora_bias_stacked
-            )
        for slice_idx in range(len(lora_b_stacked)):
            self._apply_expand(
                y,
@@ -276,7 +268,6 @@ class PunicaWrapperCPU(PunicaWrapperBase):
        x: torch.Tensor,
        lora_a_stacked: tuple[torch.Tensor, ...],
        lora_b_stacked: tuple[torch.Tensor, ...],
-        lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
        scale: float,
        output_slices: tuple[int, ...],
        *,
@@ -293,25 +284,19 @@ class PunicaWrapperCPU(PunicaWrapperBase):
                    @ lora_a_stacked[indices[i], layer_idx, :, :]
                    @ lora_b_stacked[indices[i], layer_idx, :, :]
                    * scale
-                    ).squeeze(0)+lora_bias_stacked[i]
+                    ).squeeze(0)
        Args:
            y (torch.Tensor): Output tensor. Will be changed in-place.
            x (torch.Tensor): Input tensor
            lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight.
            lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight.
-            lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias.
            scale (float): Scaling factor.
            output_slices (tuple[int, ...]): Every slice's size.
            buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None.
        """
        assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
-        if lora_bias_stacked is not None:
-            assert len(lora_bias_stacked) == len(output_slices)
-            y = self._apply_bias(
-                self.token_lora_indices, y, output_slices, lora_bias_stacked
-            )
        if buffer is None:
            r = lora_b_stacked[0].size(-1)
@@ -323,7 +308,7 @@ class PunicaWrapperCPU(PunicaWrapperBase):
        )
        self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
        self.add_expand(
-            y, buffer, lora_b_stacked, None, output_slices, add_inputs=True, **kwargs
+            y, buffer, lora_b_stacked, output_slices, add_inputs=True, **kwargs
        )
    def add_lora_logits(

View File

@@ -101,36 +101,29 @@ class PunicaWrapperGPU(PunicaWrapperBase):
        y: torch.Tensor,
        x: torch.Tensor,
        lora_b_stacked: tuple[torch.Tensor, ...],
-        lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
        output_slices: tuple[int, ...],
        offset_start: int = 0,
        add_inputs=True,
        **kwargs,
    ) -> None:
        """
-        Performs GEMM and bias addition for multiple slices of lora_b.
+        Performs GEMM for multiple slices of lora_b.
        Semantics:
            for i in range(len(lora_b_stacked)):
                slice = output_slices[i]
-                y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
-                    lora_bias_stacked[i]
+                y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i]
                offset += slice
        Args:
            y (torch.Tensor): Output tensor.
            x (torch.Tensor): Input tensors
            lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
-            lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]):
-                bias's weight
            output_slices (tuple[int, ...]): Every slice's size
            add_inputs (bool): Defaults to True.
        """
        y_org = y
        y = y.view(-1, y.shape[-1])
-        if lora_bias_stacked is not None:
-            token_lora_indices = torch.narrow(self._token_lora_indices, 0, 0, y.size(0))
-            self._apply_bias(token_lora_indices, y, output_slices, lora_bias_stacked)
        assert x.ndim == 3
        assert x.size(0) == len(output_slices)
@@ -183,7 +176,6 @@ class PunicaWrapperGPU(PunicaWrapperBase):
        x: torch.Tensor,
        lora_a_stacked: tuple[torch.Tensor, ...],
        lora_b_stacked: tuple[torch.Tensor, ...],
-        lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
        scale: float,
        output_slices: tuple[int, ...],
        *,
@@ -200,26 +192,18 @@ class PunicaWrapperGPU(PunicaWrapperBase):
                    @ lora_a_stacked[indices[i], layer_idx, :, :]
                    @ lora_b_stacked[indices[i], layer_idx, :, :]
                    * scale
-                    ).squeeze(0)+lora_bias_stacked[i]
+                    ).squeeze(0)
        Args:
            y (torch.Tensor): Output tensor. Will be changed in-place.
            x (torch.Tensor): Input tensor
            lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight.
            lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight.
-            lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias.
            scale (float): Scaling factor.
            output_slices (tuple[int, ...]): Every slice's size.
            buffer (Optional[torch.Tensor]): Defaults to None.
        """
        assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
-        if lora_bias_stacked is not None:
-            assert len(lora_bias_stacked) == len(output_slices)
-            token_lora_indices = torch.narrow(self._token_lora_indices, 0, 0, y.size(0))
-            y = self._apply_bias(
-                token_lora_indices, y, output_slices, lora_bias_stacked
-            )
        if buffer is None:
            r = lora_b_stacked[0].size(-1)
@@ -241,7 +225,6 @@ class PunicaWrapperGPU(PunicaWrapperBase):
            y,
            buffer,  # type: ignore
            lora_b_stacked,
-            None,
            output_slices,
            add_inputs=True,
            **kwargs,

View File

@@ -139,28 +139,24 @@ class PunicaWrapperTPU(PunicaWrapperBase):
        y: torch.Tensor,
        x: Union[tuple[torch.Tensor, ...], torch.Tensor],
        lora_b_stacked: tuple[torch.Tensor, ...],
-        lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
        output_slices: tuple[int, ...],
        offset_start: int = 0,
        add_inputs=True,
        **kwargs,
    ) -> torch.Tensor:
        """
-        Performs GEMM and bias addition for multiple slices of lora_b.
+        Performs GEMM for multiple slices of lora_b.
        Semantics:
            for i in range(len(lora_b_stacked)):
                slice = output_slices[i]
-                y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
-                    lora_bias_stacked[i]
+                y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i]
                offset += slice
        Args:
            y (torch.Tensor): Output tensor.
            x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
            lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
-            lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]):
-                bias's weight
            output_slices (tuple[int, ...]): Every slice's size
            add_inputs (bool): Defaults to True.
        """
@@ -168,10 +164,6 @@ class PunicaWrapperTPU(PunicaWrapperBase):
        y = y.view(-1, y.shape[-1])
        offset_left = 0
-        if lora_bias_stacked is not None:
-            y = self._apply_bias(
-                self._get_token_lora_indices(y), y, output_slices, lora_bias_stacked
-            )
        for slice_idx in range(len(lora_b_stacked)):
            y = self.expand_slice(
                y,
@@ -214,7 +206,6 @@ class PunicaWrapperTPU(PunicaWrapperBase):
        x: torch.Tensor,
        lora_a_stacked: tuple[torch.Tensor, ...],
        lora_b_stacked: tuple[torch.Tensor, ...],
-        lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
        scale: float,
        output_slices: tuple[int, ...],
        *,
@@ -231,25 +222,19 @@ class PunicaWrapperTPU(PunicaWrapperBase):
                    @ lora_a_stacked[indices[i], layer_idx, :, :]
                    @ lora_b_stacked[indices[i], layer_idx, :, :]
                    * scale
-                    ).squeeze(0)+lora_bias_stacked[i]
+                    ).squeeze(0)
        Args:
            y (torch.Tensor): Output tensor. Will not be changed in-place.
            x (torch.Tensor): Input tensor (T, E)
            lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight.
            lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight.
-            lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias.
            scale (float): Scaling factor.
            output_slices (tuple[int, ...]): Every slice's size.
            buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None.
        """
        assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
-        if lora_bias_stacked is not None:
-            assert len(lora_bias_stacked) == len(output_slices)
-            y = self._apply_bias(
-                self._get_token_lora_indices(y), y, output_slices, lora_bias_stacked
-            )
        if buffer is None:
            r = lora_b_stacked[0].size(-1)
@@ -261,7 +246,7 @@ class PunicaWrapperTPU(PunicaWrapperBase):
        )
        buffer = self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
        return self.add_expand(
-            y, buffer, lora_b_stacked, None, output_slices, add_inputs=True, **kwargs
+            y, buffer, lora_b_stacked, output_slices, add_inputs=True, **kwargs
        )
    def add_lora_logits(
@@ -299,43 +284,6 @@ class PunicaWrapperTPU(PunicaWrapperBase):
        y = bgmv_expand(buffer, lora_b_stacked, y, sampler_indices, add_inputs=True)
        return y.view_as(y_org)
-    def _apply_bias(
-        self,
-        indices: torch.Tensor,
-        output: torch.Tensor,
-        output_slices: tuple[int, ...],
-        lora_bias_stacked: tuple[Optional[torch.Tensor], ...],
-    ):
-        """Applies bias to output
-        Input shapes:
-            lora_bias_stacked:      3 element tuple of (num_loras, output_dim)
-            indices:                (batch_size)
-            output:                 (batch_size, q_slice_size + 2*kv_slice_size)
-            output_slices:          n-1 element tuple of (slice_size...),
-                                    where n is number of slices
-        """
-        org_output = output
-        output = output.view(-1, output.shape[-1])
-        indices = indices.view(-1)
-        offset_left = 0
-        for slice_idx, slice in enumerate(output_slices):
-            bias = lora_bias_stacked[slice_idx]
-            if bias is not None:
-                bias = bias.view(-1, bias.shape[-1])
-                bias = bias[indices]
-                bias = torch.where(indices[:, None] == -1, 0, bias)
-                bias = F.pad(
-                    bias, (offset_left, output.shape[1] - (offset_left + slice), 0, 0)
-                )
-                output += bias
-            offset_left += slice
-        return output.view_as(org_output)
    # This performs the same tensor ops as the base method, except it does them
    # on the CPU then transfers the results to the TPU
    def _update_base_metadata(

View File

@@ -108,36 +108,29 @@ class PunicaWrapperXPU(PunicaWrapperBase):
        y: torch.Tensor,
        x: torch.Tensor,
        lora_b_stacked: tuple[torch.Tensor, ...],
-        lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
        output_slices: tuple[int, ...],
        offset_start: int = 0,
        add_inputs=True,
        **kwargs,
    ) -> None:
        """
-        Performs GEMM and bias addition for multiple slices of lora_b.
+        Performs GEMM for multiple slices of lora_b.
        Semantics:
            for i in range(len(lora_b_stacked)):
                slice = output_slices[i]
-                y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
-                    lora_bias_stacked[i]
+                y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i]
                offset += slice
        Args:
            y (torch.Tensor): Output tensor.
            x (torch.Tensor): Input tensors
            lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
-            lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]):
-                bias's weight
            output_slices (tuple[int, ...]): Every slice's size
            add_inputs (bool): Defaults to True.
        """
        y_org = y
        y = y.view(-1, y.shape[-1])
-        if lora_bias_stacked is not None:
-            token_lora_indices = self._get_token_lora_indices(y)
-            self._apply_bias(token_lora_indices, y, output_slices, lora_bias_stacked)
        assert x.ndim == 3
        assert x.size(0) == len(output_slices)
@@ -184,7 +177,6 @@ class PunicaWrapperXPU(PunicaWrapperBase):
        x: torch.Tensor,
        lora_a_stacked: tuple[torch.Tensor, ...],
        lora_b_stacked: tuple[torch.Tensor, ...],
-        lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
        scale: float,
        output_slices: tuple[int, ...],
        *,
@@ -201,26 +193,19 @@ class PunicaWrapperXPU(PunicaWrapperBase):
                    @ lora_a_stacked[indices[i], layer_idx, :, :]
                    @ lora_b_stacked[indices[i], layer_idx, :, :]
                    * scale
-                    ).squeeze(0)+lora_bias_stacked[i]
+                    ).squeeze(0)
        Args:
            y (torch.Tensor): Output tensor. Will be changed in-place.
            x (torch.Tensor): Input tensor
            lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight.
            lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight.
-            lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias.
            scale (float): Scaling factor.
            output_slices (tuple[int, ...]): Every slice's size.
            buffer (Optional[torch.Tensor]): Defaults to None.
        """
        assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
-        if lora_bias_stacked is not None:
-            assert len(lora_bias_stacked) == len(output_slices)
-            token_lora_indices = self._get_token_lora_indices(y)
-            y = self._apply_bias(
-                token_lora_indices, y, output_slices, lora_bias_stacked
-            )
        if buffer is None:
            r = lora_b_stacked[0].size(-1)
@@ -242,7 +227,6 @@ class PunicaWrapperXPU(PunicaWrapperBase):
            y,
            buffer,  # type: ignore
            lora_b_stacked,
-            None,
            output_slices,
            add_inputs=True,
            **kwargs,

View File

@@ -112,7 +112,7 @@ def replace_submodule(
 def parse_fine_tuned_lora_name(
     name: str, weights_mapper: Optional["WeightsMapper"] = None
-) -> tuple[str, bool, bool]:
+) -> tuple[str, bool]:
     """Parse the name of lora weights.
     args:
@@ -124,7 +124,6 @@ def parse_fine_tuned_lora_name(
        tuple(module_name, is_lora_a):
            module_name: the name of the module, e.g. model.dense1,
            is_lora_a whether the tensor is lora_a or lora_b.
-            is_bias whether the tensor is lora bias.
    """
    # LoRA weight qualified name usually starts with `base_model.model.`,
@@ -146,15 +145,11 @@ def parse_fine_tuned_lora_name(
    parts = name.split(".")
    if parts[-1] == "weight" and (parts[-2] == "lora_A" or parts[-2] == "lora_B"):
        new_name = ".".join(parts[start_index:-2])
-        return new_name, parts[-2] == "lora_A", False
+        return new_name, parts[-2] == "lora_A"
    if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B":
        new_name = ".".join(parts[start_index:-1])
-        return new_name, parts[-1] == "lora_embedding_A", False
+        return new_name, parts[-1] == "lora_embedding_A"
-    if parts[-1] == "bias":
-        new_name = ".".join(parts[start_index:-2])
-        return new_name, False, True
    raise ValueError(f"{name} is unsupported LoRA weight")