Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
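The conversion below is mechanical throughout: function signatures and call sites are rewrapped into ruff's one-argument-per-line style with trailing commas, single-quoted strings become double-quoted, and comment spacing is normalized. The tooling switch itself is not visible in these hunks; as an assumption, the equivalent ruff invocations would be "ruff format" (replacing yapf) and "ruff check --select I --fix" (replacing isort's import sorting), with the actual configuration living elsewhere in the repository.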
@@ -21,11 +21,18 @@ def reset_device(reset_default_device):
 
 # Utility shrink and expand operations used as reference implementations.
 def sgmv_shrink_for_nslices(
-        nslices: int, inputs_tensor: torch.Tensor,
-        lora_weights_lst: list[torch.Tensor], out_tensor: torch.Tensor,
-        b_seq_start_loc: torch.Tensor, seq_len_tensor: torch.Tensor,
-        prompt_lora_mapping: torch.Tensor, batches: int, max_seq_length: int,
-        num_tokens: int, scaling: float):
+    nslices: int,
+    inputs_tensor: torch.Tensor,
+    lora_weights_lst: list[torch.Tensor],
+    out_tensor: torch.Tensor,
+    b_seq_start_loc: torch.Tensor,
+    seq_len_tensor: torch.Tensor,
+    prompt_lora_mapping: torch.Tensor,
+    batches: int,
+    max_seq_length: int,
+    num_tokens: int,
+    scaling: float,
+):
     """
     Wrapper around torch_ops.sgmv_shrink that handles any nslices.
     """
@@ -44,15 +51,20 @@ def sgmv_shrink_for_nslices(
     )
 
 
-def sgmv_expand_for_nslices(nslices: int, hidden_size: int,
-                            inputs_tensor: torch.Tensor,
-                            lora_weights_lst: list[torch.Tensor],
-                            out_tensor: torch.Tensor,
-                            b_seq_start_loc: torch.Tensor,
-                            seq_len_tensor: torch.Tensor,
-                            prompt_lora_mapping: torch.Tensor, batches: int,
-                            max_seq_length: int, num_tokens: int,
-                            add_inputs: bool) -> None:
+def sgmv_expand_for_nslices(
+    nslices: int,
+    hidden_size: int,
+    inputs_tensor: torch.Tensor,
+    lora_weights_lst: list[torch.Tensor],
+    out_tensor: torch.Tensor,
+    b_seq_start_loc: torch.Tensor,
+    seq_len_tensor: torch.Tensor,
+    prompt_lora_mapping: torch.Tensor,
+    batches: int,
+    max_seq_length: int,
+    num_tokens: int,
+    add_inputs: bool,
+) -> None:
     """
     Wrapper around torch_ops.sgmv_expand that handles any nslices.
     """
@@ -94,10 +106,17 @@ def sgmv_expand_for_nslices(nslices: int, hidden_size: int,
 _dict_lock = Lock()
 
 
-def check_lora_shrink_kernel(batches: int, num_loras: int, rank: int,
-                             hidden_size: int, nslices: int,
-                             dtype: torch.dtype, device: str, seq_length: int,
-                             scaling: float):
+def check_lora_shrink_kernel(
+    batches: int,
+    num_loras: int,
+    rank: int,
+    hidden_size: int,
+    nslices: int,
+    dtype: torch.dtype,
+    device: str,
+    seq_length: int,
+    scaling: float,
+):
     """
     Compare outputs of torch_ops.sgmv_shrink and triton_ops.lora_shrink
     kernels.
@@ -116,14 +135,19 @@ def check_lora_shrink_kernel(batches: int, num_loras: int, rank: int,
     max_seq_length, token_nums = data.meta()
 
     # Setup metadata information for SGMV and reference kernels
-    sgmv_meta_args = (data.b_seq_start_loc, data.seq_len_tensor,
-                      data.prompt_lora_mapping, batches, max_seq_length,
-                      token_nums)
+    sgmv_meta_args = (
+        data.b_seq_start_loc,
+        data.seq_len_tensor,
+        data.prompt_lora_mapping,
+        batches,
+        max_seq_length,
+        token_nums,
+    )
 
     # Setup metadata information for the LoRA kernel.
-    lora_meta = LoRAKernelMeta.make(max_loras=num_loras,
-                                    max_num_tokens=token_nums,
-                                    device='cuda')
+    lora_meta = LoRAKernelMeta.make(
+        max_loras=num_loras, max_num_tokens=token_nums, device="cuda"
+    )
     lora_meta.prepare_tensors(data.token_lora_mapping)
 
     ref_out_tensor = data.ref_out_tensor
@@ -154,10 +178,17 @@ def check_lora_shrink_kernel(batches: int, num_loras: int, rank: int,
     assert_close(out_tensor, ref_out_tensor)
 
 
-def check_lora_expand_kernel(batches: int, num_loras: int, rank: int,
-                             hidden_size: int, nslices: int,
-                             dtype: torch.dtype, device: str, seq_length: int,
-                             add_inputs: bool):
+def check_lora_expand_kernel(
+    batches: int,
+    num_loras: int,
+    rank: int,
+    hidden_size: int,
+    nslices: int,
+    dtype: torch.dtype,
+    device: str,
+    seq_length: int,
+    add_inputs: bool,
+):
     """
     Compare outputs of torch_ops.sgmv_expand and triton_ops.lora_expand
     kernels.
@@ -177,14 +208,19 @@ def check_lora_expand_kernel(batches: int, num_loras: int, rank: int,
     max_seq_length, token_nums = data.meta()
 
     # Setup metadata information for SGMV and reference kernels
-    sgmv_meta_args = (data.b_seq_start_loc, data.seq_len_tensor,
-                      data.prompt_lora_mapping, batches, max_seq_length,
-                      token_nums)
+    sgmv_meta_args = (
+        data.b_seq_start_loc,
+        data.seq_len_tensor,
+        data.prompt_lora_mapping,
+        batches,
+        max_seq_length,
+        token_nums,
+    )
 
     # Setup metadata information for the LoRA kernel.
-    lora_meta = LoRAKernelMeta.make(max_loras=num_loras,
-                                    max_num_tokens=token_nums,
-                                    device='cuda')
+    lora_meta = LoRAKernelMeta.make(
+        max_loras=num_loras, max_num_tokens=token_nums, device="cuda"
+    )
     lora_meta.prepare_tensors(data.token_lora_mapping)
 
     # Setup output tensors
@@ -194,21 +230,25 @@ def check_lora_expand_kernel(batches: int, num_loras: int, rank: int,
     with _dict_lock:
         # lora_expand kernel
         _LORA_B_PTR_DICT.clear()
-        triton_ops.lora_expand(data.inputs_tensor,
-                               data.lora_weights,
-                               out_tensor,
-                               *lora_meta.meta_args(token_nums=token_nums),
-                               offset_start=0,
-                               add_inputs=add_inputs)
+        triton_ops.lora_expand(
+            data.inputs_tensor,
+            data.lora_weights,
+            out_tensor,
+            *lora_meta.meta_args(token_nums=token_nums),
+            offset_start=0,
+            add_inputs=add_inputs,
+        )
 
     # Reference
-    sgmv_expand_for_nslices(nslices,
-                            hidden_size,
-                            data.inputs_tensor,
-                            data.lora_weights,
-                            ref_out_tensor,
-                            *sgmv_meta_args,
-                            add_inputs=add_inputs)
+    sgmv_expand_for_nslices(
+        nslices,
+        hidden_size,
+        data.inputs_tensor,
+        data.lora_weights,
+        ref_out_tensor,
+        *sgmv_meta_args,
+        add_inputs=add_inputs,
+    )
 
     assert_close(out_tensor, ref_out_tensor)
 
@@ -299,7 +339,7 @@ HIDDEN_SIZES = [
     128000,
     128256,
 ]
-#The size of TP
+# The size of TP
 divisibility = [1, 2, 8, 16, 64]
 
 all_hidden_size = []
@@ -331,10 +371,10 @@ DEVICES = [f"cuda:{0}"]
 SEED = [0]
 
 
-@pytest.mark.parametrize("batches", test_params['batches'])
-@pytest.mark.parametrize("num_loras", test_params['num_loras'])
-@pytest.mark.parametrize("rank", test_params['max_ranks'])
-@pytest.mark.parametrize("hidden_size", test_params['hidden_sizes'])
+@pytest.mark.parametrize("batches", test_params["batches"])
+@pytest.mark.parametrize("num_loras", test_params["num_loras"])
+@pytest.mark.parametrize("rank", test_params["max_ranks"])
+@pytest.mark.parametrize("hidden_size", test_params["hidden_sizes"])
 @pytest.mark.parametrize("nslices", [1, 2, 3])
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("device", DEVICES)
@@ -358,31 +398,35 @@ def test_kernels(
     current_platform.seed_everything(seed)
 
     if op_type == "shrink":
-        check_lora_shrink_kernel(batches=batches,
-                                 num_loras=num_loras,
-                                 rank=rank,
-                                 hidden_size=hidden_size,
-                                 nslices=nslices,
-                                 dtype=dtype,
-                                 device=device,
-                                 seq_length=128,
-                                 scaling=0.5)
+        check_lora_shrink_kernel(
+            batches=batches,
+            num_loras=num_loras,
+            rank=rank,
+            hidden_size=hidden_size,
+            nslices=nslices,
+            dtype=dtype,
+            device=device,
+            seq_length=128,
+            scaling=0.5,
+        )
     else:
-        check_lora_expand_kernel(batches=batches,
-                                 num_loras=num_loras,
-                                 rank=rank,
-                                 hidden_size=hidden_size,
-                                 nslices=nslices,
-                                 dtype=dtype,
-                                 device=device,
-                                 seq_length=128,
-                                 add_inputs=True)
+        check_lora_expand_kernel(
+            batches=batches,
+            num_loras=num_loras,
+            rank=rank,
+            hidden_size=hidden_size,
+            nslices=nslices,
+            dtype=dtype,
+            device=device,
+            seq_length=128,
+            add_inputs=True,
+        )
 
 
-@pytest.mark.parametrize("batches", hs_test_params['batches'])
-@pytest.mark.parametrize("num_loras", hs_test_params['num_loras'])
-@pytest.mark.parametrize("rank", hs_test_params['max_ranks'])
-@pytest.mark.parametrize("hidden_size", hs_test_params['hidden_sizes'])
+@pytest.mark.parametrize("batches", hs_test_params["batches"])
+@pytest.mark.parametrize("num_loras", hs_test_params["num_loras"])
+@pytest.mark.parametrize("rank", hs_test_params["max_ranks"])
+@pytest.mark.parametrize("hidden_size", hs_test_params["hidden_sizes"])
 @pytest.mark.parametrize("nslices", [1, 2, 3])
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("device", DEVICES)
@@ -406,22 +450,26 @@ def test_kernels_hidden_size(
     current_platform.seed_everything(seed)
 
     if op_type == "shrink":
-        check_lora_shrink_kernel(batches=batches,
-                                 num_loras=num_loras,
-                                 rank=rank,
-                                 hidden_size=hidden_size,
-                                 nslices=nslices,
-                                 dtype=dtype,
-                                 device=device,
-                                 seq_length=128,
-                                 scaling=0.5)
+        check_lora_shrink_kernel(
+            batches=batches,
+            num_loras=num_loras,
+            rank=rank,
+            hidden_size=hidden_size,
+            nslices=nslices,
+            dtype=dtype,
+            device=device,
+            seq_length=128,
+            scaling=0.5,
+        )
     else:
-        check_lora_expand_kernel(batches=batches,
-                                 num_loras=num_loras,
-                                 rank=rank,
-                                 hidden_size=hidden_size,
-                                 nslices=nslices,
-                                 dtype=dtype,
-                                 device=device,
-                                 seq_length=128,
-                                 add_inputs=True)
+        check_lora_expand_kernel(
+            batches=batches,
+            num_loras=num_loras,
+            rank=rank,
+            hidden_size=hidden_size,
+            nslices=nslices,
+            dtype=dtype,
+            device=device,
+            seq_length=128,
+            add_inputs=True,
+        )
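
Since every hunk above only rewraps code, the conversion can be sanity-checked by confirming that each file parses to the same AST before and after reformatting. A minimal sketch of such a check in Python (an illustrative helper, not part of this commit):

    import ast

    def is_formatting_only(old_src: str, new_src: str) -> bool:
        # ast.dump omits line/column attributes by default, so rewrapped
        # signatures, added trailing commas, and quote-style changes all
        # compare equal; any semantic difference would make this False.
        return ast.dump(ast.parse(old_src)) == ast.dump(ast.parse(new_src))

Docstring edits would register as AST differences, but pure comment edits (such as "#The size of TP" becoming "# The size of TP") do not, since comments never reach the AST.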