Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
commit d6953beb91 (parent 17edd8a807)
Author: Harry Mellor
Date: 2025-10-05 15:06:22 +01:00
Committed by: GitHub
1508 changed files with 115244 additions and 94146 deletions
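The change itself is mechanical: ruff format (a Black-style formatter) takes over code layout from yapf, and ruff's I (isort) rule group takes over import sorting, so a single tool replaces two. The two most visible effects in the diff below are that yapf's parenthesis-aligned continuation lines become one-argument-per-line blocks (ruff format honors the "magic trailing comma": a trailing comma forces each element onto its own line) and that single-quoted strings become double-quoted. A minimal before/after sketch, using a hypothetical function that does not appear in this commit:

import torch

# yapf (before): continuation lines aligned under the opening parenthesis
def scale(x: torch.Tensor, factor: float,
          offset: float) -> torch.Tensor:
    return x * factor + offset

# ruff format (after): the magic trailing comma keeps one parameter per line
def scale(
    x: torch.Tensor,
    factor: float,
    offset: float,
) -> torch.Tensor:
    return x * factor + offset

(Defining scale twice is only to show both layouts side by side; a formatted file would contain one.) The usual replacement workflow is "ruff format" for layout plus "ruff check --select I --fix" for import sorting, both configured from pyproject.toml.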


@@ -21,11 +21,18 @@ def reset_device(reset_default_device):
 
 # Utility shrink and expand operations used as reference implementations.
 def sgmv_shrink_for_nslices(
-        nslices: int, inputs_tensor: torch.Tensor,
-        lora_weights_lst: list[torch.Tensor], out_tensor: torch.Tensor,
-        b_seq_start_loc: torch.Tensor, seq_len_tensor: torch.Tensor,
-        prompt_lora_mapping: torch.Tensor, batches: int, max_seq_length: int,
-        num_tokens: int, scaling: float):
+    nslices: int,
+    inputs_tensor: torch.Tensor,
+    lora_weights_lst: list[torch.Tensor],
+    out_tensor: torch.Tensor,
+    b_seq_start_loc: torch.Tensor,
+    seq_len_tensor: torch.Tensor,
+    prompt_lora_mapping: torch.Tensor,
+    batches: int,
+    max_seq_length: int,
+    num_tokens: int,
+    scaling: float,
+):
     """
     Wrapper around torch_ops.sgmv_shrink that handles any nslices.
     """
@@ -44,15 +51,20 @@ def sgmv_shrink_for_nslices(
         )
 
 
-def sgmv_expand_for_nslices(nslices: int, hidden_size: int,
-                            inputs_tensor: torch.Tensor,
-                            lora_weights_lst: list[torch.Tensor],
-                            out_tensor: torch.Tensor,
-                            b_seq_start_loc: torch.Tensor,
-                            seq_len_tensor: torch.Tensor,
-                            prompt_lora_mapping: torch.Tensor, batches: int,
-                            max_seq_length: int, num_tokens: int,
-                            add_inputs: bool) -> None:
+def sgmv_expand_for_nslices(
+    nslices: int,
+    hidden_size: int,
+    inputs_tensor: torch.Tensor,
+    lora_weights_lst: list[torch.Tensor],
+    out_tensor: torch.Tensor,
+    b_seq_start_loc: torch.Tensor,
+    seq_len_tensor: torch.Tensor,
+    prompt_lora_mapping: torch.Tensor,
+    batches: int,
+    max_seq_length: int,
+    num_tokens: int,
+    add_inputs: bool,
+) -> None:
     """
     Wrapper around torch_ops.sgmv_expand that handles any nslices.
     """
@@ -94,10 +106,17 @@ def sgmv_expand_for_nslices(nslices: int, hidden_size: int,
 _dict_lock = Lock()
 
 
-def check_lora_shrink_kernel(batches: int, num_loras: int, rank: int,
-                             hidden_size: int, nslices: int,
-                             dtype: torch.dtype, device: str, seq_length: int,
-                             scaling: float):
+def check_lora_shrink_kernel(
+    batches: int,
+    num_loras: int,
+    rank: int,
+    hidden_size: int,
+    nslices: int,
+    dtype: torch.dtype,
+    device: str,
+    seq_length: int,
+    scaling: float,
+):
     """
     Compare outputs of torch_ops.sgmv_shrink and triton_ops.lora_shrink
     kernels.
@@ -116,14 +135,19 @@ def check_lora_shrink_kernel(batches: int, num_loras: int, rank: int,
     max_seq_length, token_nums = data.meta()
 
     # Setup metadata information for SGMV and reference kernels
-    sgmv_meta_args = (data.b_seq_start_loc, data.seq_len_tensor,
-                      data.prompt_lora_mapping, batches, max_seq_length,
-                      token_nums)
+    sgmv_meta_args = (
+        data.b_seq_start_loc,
+        data.seq_len_tensor,
+        data.prompt_lora_mapping,
+        batches,
+        max_seq_length,
+        token_nums,
+    )
 
     # Setup metadata information for the LoRA kernel.
-    lora_meta = LoRAKernelMeta.make(max_loras=num_loras,
-                                    max_num_tokens=token_nums,
-                                    device='cuda')
+    lora_meta = LoRAKernelMeta.make(
+        max_loras=num_loras, max_num_tokens=token_nums, device="cuda"
+    )
     lora_meta.prepare_tensors(data.token_lora_mapping)
 
     ref_out_tensor = data.ref_out_tensor
@@ -154,10 +178,17 @@ def check_lora_shrink_kernel(batches: int, num_loras: int, rank: int,
     assert_close(out_tensor, ref_out_tensor)
 
 
-def check_lora_expand_kernel(batches: int, num_loras: int, rank: int,
-                             hidden_size: int, nslices: int,
-                             dtype: torch.dtype, device: str, seq_length: int,
-                             add_inputs: bool):
+def check_lora_expand_kernel(
+    batches: int,
+    num_loras: int,
+    rank: int,
+    hidden_size: int,
+    nslices: int,
+    dtype: torch.dtype,
+    device: str,
+    seq_length: int,
+    add_inputs: bool,
+):
     """
     Compare outputs of torch_ops.sgmv_expand and triton_ops.lora_expand
     kernels.
@@ -177,14 +208,19 @@ def check_lora_expand_kernel(batches: int, num_loras: int, rank: int,
     max_seq_length, token_nums = data.meta()
 
     # Setup metadata information for SGMV and reference kernels
-    sgmv_meta_args = (data.b_seq_start_loc, data.seq_len_tensor,
-                      data.prompt_lora_mapping, batches, max_seq_length,
-                      token_nums)
+    sgmv_meta_args = (
+        data.b_seq_start_loc,
+        data.seq_len_tensor,
+        data.prompt_lora_mapping,
+        batches,
+        max_seq_length,
+        token_nums,
+    )
 
     # Setup metadata information for the LoRA kernel.
-    lora_meta = LoRAKernelMeta.make(max_loras=num_loras,
-                                    max_num_tokens=token_nums,
-                                    device='cuda')
+    lora_meta = LoRAKernelMeta.make(
+        max_loras=num_loras, max_num_tokens=token_nums, device="cuda"
+    )
     lora_meta.prepare_tensors(data.token_lora_mapping)
 
     # Setup output tensors
@@ -194,21 +230,25 @@ def check_lora_expand_kernel(batches: int, num_loras: int, rank: int,
     with _dict_lock:
         # lora_expand kernel
        _LORA_B_PTR_DICT.clear()
-        triton_ops.lora_expand(data.inputs_tensor,
-                               data.lora_weights,
-                               out_tensor,
-                               *lora_meta.meta_args(token_nums=token_nums),
-                               offset_start=0,
-                               add_inputs=add_inputs)
+        triton_ops.lora_expand(
+            data.inputs_tensor,
+            data.lora_weights,
+            out_tensor,
+            *lora_meta.meta_args(token_nums=token_nums),
+            offset_start=0,
+            add_inputs=add_inputs,
+        )
 
     # Reference
-    sgmv_expand_for_nslices(nslices,
-                            hidden_size,
-                            data.inputs_tensor,
-                            data.lora_weights,
-                            ref_out_tensor,
-                            *sgmv_meta_args,
-                            add_inputs=add_inputs)
+    sgmv_expand_for_nslices(
+        nslices,
+        hidden_size,
+        data.inputs_tensor,
+        data.lora_weights,
+        ref_out_tensor,
+        *sgmv_meta_args,
+        add_inputs=add_inputs,
+    )
 
     assert_close(out_tensor, ref_out_tensor)
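assert_close here is the test module's own helper rather than a direct torch.testing.assert_close call; a minimal sketch of such a helper, assuming it simply picks tolerances by dtype (the exact values are assumptions, not taken from this diff):

import torch

def assert_close(a: torch.Tensor, b: torch.Tensor) -> None:
    # Hypothetical dtype-dependent tolerances; low-precision dtypes
    # need looser bounds than float32.
    rtol, atol = {
        torch.float16: (1e-2, 1e-2),
        torch.bfloat16: (1e-2, 1e-2),
        torch.float32: (1e-3, 1e-3),
    }[a.dtype]
    torch.testing.assert_close(a, b, rtol=rtol, atol=atol)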
@@ -299,7 +339,7 @@ HIDDEN_SIZES = [
     128000,
     128256,
 ]
-#The size of TP
+# The size of TP
 divisibility = [1, 2, 8, 16, 64]
 
 all_hidden_size = []
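This hunk only fixes comment spacing, but the names sketch the intent: all_hidden_size is presumably populated by dividing every base hidden size by every tensor-parallel degree, so the kernels are also exercised at TP-shard widths. A self-contained sketch under that assumption (the construction itself is not shown in the diff):

HIDDEN_SIZES = [128000, 128256]  # abbreviated stand-in for the full list
divisibility = [1, 2, 8, 16, 64]  # the TP sizes

# Assumed construction: one shard-width entry per (TP size, hidden size).
all_hidden_size = []
for div in divisibility:
    for hidden_size in HIDDEN_SIZES:
        all_hidden_size.append(hidden_size // div)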
@@ -331,10 +371,10 @@ DEVICES = [f"cuda:{0}"]
 SEED = [0]
 
 
-@pytest.mark.parametrize("batches", test_params['batches'])
-@pytest.mark.parametrize("num_loras", test_params['num_loras'])
-@pytest.mark.parametrize("rank", test_params['max_ranks'])
-@pytest.mark.parametrize("hidden_size", test_params['hidden_sizes'])
+@pytest.mark.parametrize("batches", test_params["batches"])
+@pytest.mark.parametrize("num_loras", test_params["num_loras"])
+@pytest.mark.parametrize("rank", test_params["max_ranks"])
+@pytest.mark.parametrize("hidden_size", test_params["hidden_sizes"])
 @pytest.mark.parametrize("nslices", [1, 2, 3])
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("device", DEVICES)
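Only the quote style changes in these decorators, but it is worth recalling how they compose: stacked pytest.mark.parametrize decorators run the test once per element of the Cartesian product of all the parameter lists, which is what fans this file out into a large test matrix. A self-contained illustration with hypothetical values:

import pytest

@pytest.mark.parametrize("batches", [1, 4])
@pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
def test_cartesian(batches, dtype):
    # Collected 2 x 2 = 4 times: one run per parameter combination.
    assert batches in (1, 4) and dtype in ("float16", "bfloat16")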
@@ -358,31 +398,35 @@ def test_kernels(
     current_platform.seed_everything(seed)
 
     if op_type == "shrink":
-        check_lora_shrink_kernel(batches=batches,
-                                 num_loras=num_loras,
-                                 rank=rank,
-                                 hidden_size=hidden_size,
-                                 nslices=nslices,
-                                 dtype=dtype,
-                                 device=device,
-                                 seq_length=128,
-                                 scaling=0.5)
+        check_lora_shrink_kernel(
+            batches=batches,
+            num_loras=num_loras,
+            rank=rank,
+            hidden_size=hidden_size,
+            nslices=nslices,
+            dtype=dtype,
+            device=device,
+            seq_length=128,
+            scaling=0.5,
+        )
     else:
-        check_lora_expand_kernel(batches=batches,
-                                 num_loras=num_loras,
-                                 rank=rank,
-                                 hidden_size=hidden_size,
-                                 nslices=nslices,
-                                 dtype=dtype,
-                                 device=device,
-                                 seq_length=128,
-                                 add_inputs=True)
+        check_lora_expand_kernel(
+            batches=batches,
+            num_loras=num_loras,
+            rank=rank,
+            hidden_size=hidden_size,
+            nslices=nslices,
+            dtype=dtype,
+            device=device,
+            seq_length=128,
+            add_inputs=True,
+        )
 
 
-@pytest.mark.parametrize("batches", hs_test_params['batches'])
-@pytest.mark.parametrize("num_loras", hs_test_params['num_loras'])
-@pytest.mark.parametrize("rank", hs_test_params['max_ranks'])
-@pytest.mark.parametrize("hidden_size", hs_test_params['hidden_sizes'])
+@pytest.mark.parametrize("batches", hs_test_params["batches"])
+@pytest.mark.parametrize("num_loras", hs_test_params["num_loras"])
+@pytest.mark.parametrize("rank", hs_test_params["max_ranks"])
+@pytest.mark.parametrize("hidden_size", hs_test_params["hidden_sizes"])
 @pytest.mark.parametrize("nslices", [1, 2, 3])
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("device", DEVICES)
@@ -406,22 +450,26 @@ def test_kernels_hidden_size(
     current_platform.seed_everything(seed)
 
     if op_type == "shrink":
-        check_lora_shrink_kernel(batches=batches,
-                                 num_loras=num_loras,
-                                 rank=rank,
-                                 hidden_size=hidden_size,
-                                 nslices=nslices,
-                                 dtype=dtype,
-                                 device=device,
-                                 seq_length=128,
-                                 scaling=0.5)
+        check_lora_shrink_kernel(
+            batches=batches,
+            num_loras=num_loras,
+            rank=rank,
+            hidden_size=hidden_size,
+            nslices=nslices,
+            dtype=dtype,
+            device=device,
+            seq_length=128,
+            scaling=0.5,
+        )
     else:
-        check_lora_expand_kernel(batches=batches,
-                                 num_loras=num_loras,
-                                 rank=rank,
-                                 hidden_size=hidden_size,
-                                 nslices=nslices,
-                                 dtype=dtype,
-                                 device=device,
-                                 seq_length=128,
-                                 add_inputs=True)
+        check_lora_expand_kernel(
+            batches=batches,
+            num_loras=num_loras,
+            rank=rank,
+            hidden_size=hidden_size,
+            nslices=nslices,
+            dtype=dtype,
+            device=device,
+            seq_length=128,
+            add_inputs=True,
+        )