[Misc] LoRA - Refactor Punica ops tests (#12970)
Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
committed by GitHub
parent c320ca8edd
commit 78a141d768
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Dict, List, Optional
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Tuple, Union
 
 import torch
 
@@ -106,6 +107,31 @@ def assert_close(a, b):
     torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
 
 
+@dataclass
+class PunicaTensors:
+    inputs_tensor: torch.Tensor
+    lora_weights: Union[torch.Tensor, List[torch.Tensor]]
+    our_out_tensor: torch.Tensor
+    ref_out_tensor: torch.Tensor
+    b_seq_start_loc: torch.Tensor
+    prompt_lora_mapping: torch.Tensor
+    seq_len_tensor: torch.Tensor
+    token_lora_mapping: torch.Tensor
+
+    def meta(self) -> Tuple[int, int]:
+        """
+        Infer max_seq_length and token_nums from the tensors
+        and return them.
+        """
+        max_seq_length = self.seq_len_tensor.max()
+        token_nums = self.seq_len_tensor.sum().item()
+        if isinstance(max_seq_length, tuple):
+            max_seq_length = max_seq_length[0].item()
+        else:
+            max_seq_length = max_seq_length.item()
+        return max_seq_length, token_nums
+
+
 def generate_data(
     batches,
     hidden_size,
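For reference, a minimal sketch of how the new PunicaTensors container and its meta() helper behave, assuming the dataclass added above is in scope; the tensor values below are illustrative only and are not taken from the tests:

import torch

seq_lens = torch.tensor([3, 5, 2])
dummy = torch.zeros(1)
tensors = PunicaTensors(
    inputs_tensor=dummy,
    lora_weights=dummy,
    our_out_tensor=dummy,
    ref_out_tensor=dummy,
    b_seq_start_loc=torch.cumsum(seq_lens, dim=0) - seq_lens,  # [0, 3, 8]
    prompt_lora_mapping=dummy,
    seq_len_tensor=seq_lens,
    token_lora_mapping=dummy,
)
max_seq_length, token_nums = tensors.meta()  # 5 and 10 for the seq_lens above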
@@ -115,7 +141,7 @@ def generate_data(
     dtype,
     op_type,
     device,
-):
+) -> PunicaTensors:
     seq_len_tensor = torch.randint(seq_length, seq_length + 1,
                                    (batches, )).to(device)
     b_seq_start_loc = torch.cumsum(
@@ -164,7 +190,8 @@ def generate_data(
         indices[current_offset:current_offset +
                 seq_len_tensor[b_id]].copy_(lora_index)
         current_offset += seq_len_tensor[b_id].item()
-    return (
+
+    return PunicaTensors(
         inputs_tensor,
         lora_weights,
         our_out_tensor,
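At the call sites, the practical effect of this change is that positional unpacking of a long tuple gives way to named-field access on the returned PunicaTensors. A hedged consumer-side sketch, assuming the dataclass above is in scope; check_punica_shapes is a hypothetical helper, not part of this change:

def check_punica_shapes(data: PunicaTensors) -> None:
    # Hypothetical sanity check: per-batch metadata tensors agree in length,
    # and meta() is consistent with seq_len_tensor.
    assert data.b_seq_start_loc.shape[0] == data.seq_len_tensor.shape[0]
    max_seq_length, token_nums = data.meta()
    assert max_seq_length == data.seq_len_tensor.max().item()
    assert token_nums == data.seq_len_tensor.sum().item()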
@@ -185,7 +212,7 @@ def generate_data_for_expand_nslices(
     dtype,
     nslices,
     device,
-):
+) -> PunicaTensors:
     seq_len_tensor = torch.randint(seq_length, seq_length + 1,
                                    (batches, )).to(device)
     b_seq_start_loc = torch.cumsum(
@@ -222,7 +249,7 @@ def generate_data_for_expand_nslices(
         current_offset += seq_len_tensor[b_id].item()
 
     lora_indices_tensor = lora_indices_tensor.to(device)
-    return (
+    return PunicaTensors(
         inputs_tensor,
         lora_weights_lst,
         our_out_tensor,
@@ -244,7 +271,7 @@ def generate_data_for_nslices(
     dtype,
     op_type,
     device,
-):
+) -> PunicaTensors:
     seq_len_tensor = torch.randint(seq_length, seq_length + 1,
                                    (batches, )).to(device)
     b_seq_start_loc = torch.cumsum(
@@ -302,7 +329,7 @@ def generate_data_for_nslices(
         current_offset += seq_len_tensor[b_id].item()
 
     lora_indices_tensor = lora_indices_tensor.to(device)
-    return (
+    return PunicaTensors(
        inputs_tensor,
        lora_weights_lst,
        our_out_tensor,
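Because lora_weights is typed as Union[torch.Tensor, List[torch.Tensor]], the single-slice generator returns one weight tensor while the nslices generators return a list of them. A small normalization sketch a consumer could use; as_weight_list is a hypothetical name, not part of this change:

from typing import List

import torch


def as_weight_list(lora_weights) -> List[torch.Tensor]:
    # Hypothetical helper: normalize PunicaTensors.lora_weights so downstream
    # code can always iterate over a list of weight tensors.
    if isinstance(lora_weights, torch.Tensor):
        return [lora_weights]
    return list(lora_weights)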