[Misc] LoRA - Refactor Punica ops tests (#12970)

Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Author: Varun Sundar Rabindranath
Date: 2025-02-11 12:56:03 +05:30
Committed by: GitHub
Parent: c320ca8edd
Commit: 78a141d768
4 changed files with 686 additions and 725 deletions


@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
-from typing import Dict, List, Optional
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Tuple, Union
 
 import torch
@@ -106,6 +107,31 @@ def assert_close(a, b):
     torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
 
 
+@dataclass
+class PunicaTensors:
+    inputs_tensor: torch.Tensor
+    lora_weights: Union[torch.Tensor, List[torch.Tensor]]
+    our_out_tensor: torch.Tensor
+    ref_out_tensor: torch.Tensor
+    b_seq_start_loc: torch.Tensor
+    prompt_lora_mapping: torch.Tensor
+    seq_len_tensor: torch.Tensor
+    token_lora_mapping: torch.Tensor
+
+    def meta(self) -> Tuple[int, int]:
+        """
+        Infer max_seq_length and token_nums from the tensors
+        and return them.
+        """
+        max_seq_length = self.seq_len_tensor.max()
+        token_nums = self.seq_len_tensor.sum().item()
+        if isinstance(max_seq_length, tuple):
+            max_seq_length = max_seq_length[0].item()
+        else:
+            max_seq_length = max_seq_length.item()
+        return max_seq_length, token_nums
+
+
 def generate_data(
     batches,
     hidden_size,
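
Note: the new `meta()` helper centralizes a computation the tests previously repeated at every call site. A minimal standalone sketch of its semantics, with illustrative values that are not from the test suite:

import torch

# One sequence length per batch; generate_data fills this with a constant
# seq_length, so max() and sum() yield the values the kernels need.
seq_len_tensor = torch.tensor([128, 128, 128])

max_seq_length = seq_len_tensor.max().item()  # 128
token_nums = seq_len_tensor.sum().item()      # 384

The `isinstance(max_seq_length, tuple)` branch in `meta()` guards the case where `max()` returns a (values, indices) pair, as `torch.max` does when given a `dim=` argument; for the plain reduction used here it returns a zero-dim tensor, so the `else` branch runs.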
@@ -115,7 +141,7 @@ def generate_data(
     dtype,
     op_type,
     device,
-):
+) -> PunicaTensors:
     seq_len_tensor = torch.randint(seq_length, seq_length + 1,
                                    (batches, )).to(device)
     b_seq_start_loc = torch.cumsum(
@@ -164,7 +190,8 @@ def generate_data(
         indices[current_offset:current_offset +
                 seq_len_tensor[b_id]].copy_(lora_index)
         current_offset += seq_len_tensor[b_id].item()
-    return (
+
+    return PunicaTensors(
         inputs_tensor,
         lora_weights,
         our_out_tensor,
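
With `generate_data` returning `PunicaTensors` instead of an 8-tuple, call sites can unpack by name and derive shape metadata via `meta()`. A hedged sketch of the resulting test pattern, assuming it sits in the same module as these helpers; `run_punica_op` is a hypothetical stand-in for whichever shrink/expand kernel the test exercises:

data = generate_data(batches, hidden_size, num_loras, rank, seq_length,
                     dtype, op_type, device)
max_seq_length, token_nums = data.meta()

# Fields are accessed by name rather than by tuple position, so adding or
# reordering tensors no longer silently breaks every unpacking site.
run_punica_op(data.inputs_tensor, data.lora_weights, data.our_out_tensor,
              data.b_seq_start_loc, data.seq_len_tensor,
              data.prompt_lora_mapping, max_seq_length, token_nums)
assert_close(data.our_out_tensor, data.ref_out_tensor)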
@@ -185,7 +212,7 @@ def generate_data_for_expand_nslices(
     dtype,
     nslices,
     device,
-):
+) -> PunicaTensors:
     seq_len_tensor = torch.randint(seq_length, seq_length + 1,
                                    (batches, )).to(device)
     b_seq_start_loc = torch.cumsum(
@@ -222,7 +249,7 @@ def generate_data_for_expand_nslices(
         current_offset += seq_len_tensor[b_id].item()
     lora_indices_tensor = lora_indices_tensor.to(device)
-    return (
+    return PunicaTensors(
         inputs_tensor,
         lora_weights_lst,
         our_out_tensor,
@@ -244,7 +271,7 @@ def generate_data_for_nslices(
     dtype,
     op_type,
     device,
-):
+) -> PunicaTensors:
     seq_len_tensor = torch.randint(seq_length, seq_length + 1,
                                    (batches, )).to(device)
     b_seq_start_loc = torch.cumsum(
@@ -302,7 +329,7 @@ def generate_data_for_nslices(
         current_offset += seq_len_tensor[b_id].item()
     lora_indices_tensor = lora_indices_tensor.to(device)
-    return (
+    return PunicaTensors(
         inputs_tensor,
         lora_weights_lst,
         our_out_tensor,
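
Both nslices generators get the same `return (...)` to `return PunicaTensors(...)` substitution, so every data-generation path now funnels through one named container. In these variants `lora_weights` carries one tensor per slice, which is what the `Union[torch.Tensor, List[torch.Tensor]]` annotation on the field accommodates. A sketch of a slice-wise consumer; the argument order here is illustrative, not taken from the diff:

data = generate_data_for_nslices(batches, hidden_size, num_loras, rank,
                                 seq_length, nslices, dtype, op_type, device)
max_seq_length, token_nums = data.meta()

# One weight tensor per slice in the nslices case.
assert isinstance(data.lora_weights, list)
for slice_idx, slice_weights in enumerate(data.lora_weights):
    pass  # invoke the per-slice op with slice_weights here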