[Misc] Remove LoRA log (#15388)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
@@ -2373,12 +2373,6 @@ class LoRAConfig:
             self.lora_dtype = model_config.dtype
         elif isinstance(self.lora_dtype, str):
             self.lora_dtype = getattr(torch, self.lora_dtype)
-        if model_config.quantization and model_config.quantization not in [
-                "awq", "gptq"
-        ]:
-            # TODO support marlin
-            logger.warning("%s quantization is not tested with LoRA yet.",
-                           model_config.quantization)
 
     def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
         # Reminder: Please update docs/source/features/compatibility_matrix.md
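For context, the dtype resolution that survives this hunk behaves as sketched below. This is a minimal standalone rendering, not vLLM's actual API; the None/"auto" fallback branch is assumed from the first context line, since the enclosing `if` sits outside the hunk.

    import torch

    def resolve_lora_dtype(lora_dtype, model_dtype):
        # Assumed: None or "auto" falls back to the model's dtype (the
        # branch just above the hunk); a string such as "bfloat16" is
        # resolved to the torch dtype of the same name via getattr.
        if lora_dtype in (None, "auto"):
            return model_dtype
        if isinstance(lora_dtype, str):
            return getattr(torch, lora_dtype)
        return lora_dtype

    assert resolve_lora_dtype("bfloat16", torch.float16) is torch.bfloat16
    assert resolve_lora_dtype(None, torch.float16) is torch.float16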
@@ -78,10 +78,6 @@ class PunicaWrapperGPU(PunicaWrapperBase):
                     ...], scale: float, **kwargs):
         """
         Performs GEMM for multiple slices of lora_a.
-        When `is_prefill is` true, it indicates that it is currently the
-        prefill stage, and the `_shrink_prefill` function should be called.
-        Otherwise, it is the decode stage, and the _shrink_decode function
-        should be called.
 
         Semantics:
             for i in range(len(lora_a_stacked)):
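In plain PyTorch, the multi-slice shrink this docstring describes computes roughly the following. This is an illustrative reference sketch under assumed shapes, not the Triton-backed implementation; the loop mirrors the `Semantics:` block, whose body is cut off by the hunk boundary.

    import torch

    def shrink_reference(y, x, lora_a_stacked, scale):
        # y: (num_slices, num_tokens, rank) accumulator (assumed shapes)
        # x: (num_tokens, hidden_dim) input activations
        # lora_a_stacked[i]: (hidden_dim, rank) lora_a weight of slice i
        for i in range(len(lora_a_stacked)):
            y[i] += (x @ lora_a_stacked[i]) * scale

The `scale` factor corresponds to LoRA's alpha-over-rank scaling.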
@@ -129,7 +125,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
             lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
                 bias's weight
             output_slices (Tuple[int, ...]): Every slice's size
             add_inputs (bool): Defaults to True.
         """
         y_org = y
         y = y.view(-1, y.shape[-1])
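The argument list above belongs to a sliced expand; read as a reference sketch, it plausibly amounts to the following (names, shapes, and the transpose are assumptions, since the hunk shows only part of the docstring):

    import torch

    def expand_slices_reference(y, buffer, lora_b_stacked, lora_bias_stacked,
                                output_slices, add_inputs=True):
        # Each slice i of width output_slices[i] owns its own column range
        # of y; lora_b projects the rank-sized buffer up to that width.
        offset = 0
        for i, width in enumerate(output_slices):
            delta = buffer[i] @ lora_b_stacked[i].float().t()
            if lora_bias_stacked is not None:
                delta = delta + lora_bias_stacked[i]
            if add_inputs:
                y[:, offset:offset + width] += delta.to(y.dtype)
            else:
                y[:, offset:offset + width] = delta.to(y.dtype)
            offset += width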
@@ -226,7 +222,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
 
         if buffer is None:
             r = lora_b_stacked[0].size(-1)
-            # We set the buffer to be float32 by default ,refer to:
+            # We set the buffer to be float32 by default, refer to:
             # https://github.com/triton-lang/triton/issues/1387
             buffer = torch.zeros(  # type: ignore
                 (len(output_slices), x.size(0), r),
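The comment fixed here carries the design rationale: the intermediate buffer defaults to float32 to sidestep the Triton accumulation issue linked above. As an allocation sketch (dimensions are illustrative):

    import torch

    num_slices, num_tokens, rank = 3, 8, 16
    # Low-rank intermediate kept in float32 rather than the model dtype,
    # per triton-lang/triton#1387; the expand step is then expected to
    # cast results back to the output tensor's dtype.
    buffer = torch.zeros((num_slices, num_tokens, rank),
                         dtype=torch.float32)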
@@ -268,16 +264,16 @@ class PunicaWrapperGPU(PunicaWrapperBase):
             y (torch.Tensor): Output tensor.
             x (torch.Tensor): Input tensor.
             lora_a_stacked (torch.Tensor): lora_a's weights.
-            lora_b_stacked (torch.Tensor):lora_b's weights.
+            lora_b_stacked (torch.Tensor): lora_b's weights.
             scale (float): Scaling factor.
-            buffer (Optional[torch.Tensor]):Default to None.
+            buffer (Optional[torch.Tensor]): Default to None.
         """
         y_org = y
         y = y.view(-1, y.shape[-1])
         x = x.view(-1, x.shape[-1])
         r = lora_b_stacked.size(-1)
         if buffer is None:
-            # We set the buffer to be float32 by default ,refer to:
+            # We set the buffer to be float32 by default, refer to:
             # https://github.com/triton-lang/triton/issues/1387
             buffer = torch.zeros((x.size(0), r),
                                  dtype=torch.float32,
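Taken together, the single-tensor shrink-then-expand documented in this hunk is, in effect, the following reference sketch. The transposes follow from `r = lora_b_stacked.size(-1)` (rank last); everything else, including the float32 casts, is an assumption mirroring the comments above rather than the real Triton path.

    import torch

    def add_lora_reference(y, x, lora_a, lora_b, scale):
        # lora_a: (rank, hidden_dim), lora_b: (out_dim, rank), so that
        # rank == lora_b.size(-1), matching the code above.
        # Shrink into a float32 intermediate ...
        buffer = (x.float() @ lora_a.float().t()) * scale
        # ... then expand and accumulate into y in y's own dtype.
        y += (buffer @ lora_b.float().t()).to(y.dtype)
        return y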