Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2025-10-05 15:06:22 +01:00 (committed by GitHub)
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

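The diff below is from one of the 1508 reformatted files, a script that plots layer-wise profiling results. Because the commit changes formatting only, the old and new sources of every touched file should parse to the same abstract syntax tree. A minimal sketch of that check (the file paths are placeholders, not part of the commit):

    import ast

    def is_formatting_only(old_path: str, new_path: str) -> bool:
        """Return True if two Python sources differ only cosmetically.

        ast.parse() discards layout details such as line breaks, quote
        style, and trailing commas, so a pure yapf -> ruff reformat
        produces an identical AST dump.
        """
        with open(old_path) as old, open(new_path) as new:
            return ast.dump(ast.parse(old.read())) == ast.dump(ast.parse(new.read()))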

@@ -18,17 +18,18 @@ import pandas as pd
 def largest_dist_from_leaf(node: dict, depth: int = 0):
     if len(node["children"]) == 0:
         return depth
-    return max([
-        largest_dist_from_leaf(child, depth=depth + 1)
-        for child in node["children"]
-    ])
+    return max(
+        [largest_dist_from_leaf(child, depth=depth + 1) for child in node["children"]]
+    )
 
 
-def get_entries_at_depth(depth: int,
-                         entries_and_traces: list[tuple[Any, Any]],
-                         node: dict,
-                         curr_depth: int = 0,
-                         trace=()):
+def get_entries_at_depth(
+    depth: int,
+    entries_and_traces: list[tuple[Any, Any]],
+    node: dict,
+    curr_depth: int = 0,
+    trace=(),
+):
     # assert that the query is at kernel or module level
     assert depth == -1 or depth == -2
@@ -40,21 +41,18 @@ def get_entries_at_depth(depth: int,
     if largest_dist_from_leaf(node) == (abs(depth) - 1):
         entries_and_traces.append((node["entry"], trace))
 
-    trace = (node["entry"]["name"], ) + trace
+    trace = (node["entry"]["name"],) + trace
     for child in node["children"]:
-        get_entries_at_depth(depth,
-                             entries_and_traces,
-                             child,
-                             curr_depth=curr_depth + 1,
-                             trace=trace)
+        get_entries_at_depth(
+            depth, entries_and_traces, child, curr_depth=curr_depth + 1, trace=trace
+        )
 
 
 def fold_nodes(root: dict, nodes_to_fold: list[str]):
     stack: list[dict] = [root]
     while len(stack) != 0:
         node = stack.pop()
-        if node['entry']['name'] in nodes_to_fold:
+        if node["entry"]["name"] in nodes_to_fold:
             node["children"] = []
             continue
         for child in node["children"]:
@@ -76,9 +74,7 @@ def trim_string_back(string: str, width: int) -> str:
 def shorten_plot_legend_strings(legend, max_char_len: int):
     for t in legend.get_texts():
-        t.set_text(
-            trim_string_back(abbreviate_known_names(t.get_text()),
-                             max_char_len))
+        t.set_text(trim_string_back(abbreviate_known_names(t.get_text()), max_char_len))
 
 
 def abbreviate_known_names(name: str) -> str:
@@ -108,15 +104,21 @@ def attempt_to_make_names_unique(entries_and_traces):
             names.add(entry["name"])
 
     for name in non_unique_names:
-        entries_and_traces_with_name = [(entry, trace)
-                                        for entry, trace in entries_and_traces
-                                        if entry["name"] == name]
+        entries_and_traces_with_name = [
+            (entry, trace)
+            for entry, trace in entries_and_traces
+            if entry["name"] == name
+        ]
 
-        zipped_traces = list(
-            zip(*[trace for _, trace in entries_and_traces_with_name]))
+        zipped_traces = list(zip(*[trace for _, trace in entries_and_traces_with_name]))
         first_trace_difference = next(
-            (i for i, trace_eles in enumerate(zipped_traces)
-             if not all_the_same(trace_eles)), None)
+            (
+                i
+                for i, trace_eles in enumerate(zipped_traces)
+                if not all_the_same(trace_eles)
+            ),
+            None,
+        )
 
         if first_trace_difference is None:
             # can't create a unique name, leave the names as they
@@ -124,34 +126,32 @@ def attempt_to_make_names_unique(entries_and_traces):
             continue
 
         for entry, trace in entries_and_traces_with_name:
-            entry["name"] = " <- ".join((entry["name"], ) +
-                                        trace[:first_trace_difference + 1])
+            entry["name"] = " <- ".join(
+                (entry["name"],) + trace[: first_trace_difference + 1]
+            )
 
 
 ## Operation grouping utils ####
-'''
+"""
 Group operations in the given dataframe by some high-level ops like,
 - gemms
 - attention
 - rms_norm
 etc.
-'''
+"""
 
 
 def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame:
     def is_rms_norm(op_name: str):
         if "rms_norm_kernel" in op_name:
             return True
 
     def is_attention_block(op_name: str):
-        if "flash_fwd" in op_name or \
-            "reshape_and_cache_flash_kernel" in op_name:
+        if "flash_fwd" in op_name or "reshape_and_cache_flash_kernel" in op_name:
             return True
 
     def is_quant(op_name: str):
-        if "scaled_fp8_quant" in op_name or \
-            "scaled_int8_quant" in op_name:
+        if "scaled_fp8_quant" in op_name or "scaled_int8_quant" in op_name:
             return True
 
     # LoRA ops
@@ -168,24 +168,27 @@ def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame:
         return "bgmv_expand" in op_name
 
     def is_cutlass_gemm_op(op_name: str):
-        return "void cutlass::Kernel" in op_name or \
-            "void cutlass::device_kernel" in op_name
+        return (
+            "void cutlass::Kernel" in op_name
+            or "void cutlass::device_kernel" in op_name
+        )
 
     def is_gemm_op(op_name: str):
         if is_quant(op_name):
             return False
-        return is_cutlass_gemm_op(op_name) or \
-            "xmma_gemm" in op_name or \
-            "gemv2T_kernel" in op_name or \
-            "splitKreduce" in op_name or \
-            "s16816gemm" in op_name
+        return (
+            is_cutlass_gemm_op(op_name)
+            or "xmma_gemm" in op_name
+            or "gemv2T_kernel" in op_name
+            or "splitKreduce" in op_name
+            or "s16816gemm" in op_name
+        )
 
     def is_elementwise_op(op_name: str):
         return "elementwise_kernel" in op_name
 
     def is_mem_op(op_name: str):
-        return "memcpy" in op_name.lower() or \
-            "memset" in op_name.lower()
+        return "memcpy" in op_name.lower() or "memset" in op_name.lower()
 
     def is_vocab_embedding_op(op_name: str):
         return "vocabparallelembed" in op_name.lower()
@@ -195,17 +198,15 @@ def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame:
         return "nccl" in op_name.lower()
 
     def is_nccl_all_reduce(op_name: str):
-        return is_nccl_op(op_name) and \
-            ("all_reduce" in op_name.lower() or \
-             "allreduce" in op_name.lower())
+        return is_nccl_op(op_name) and (
+            "all_reduce" in op_name.lower() or "allreduce" in op_name.lower()
+        )
 
     def is_nccl_gather(op_name: str):
-        return is_nccl_op(op_name) and \
-            "gather" in op_name.lower()
+        return is_nccl_op(op_name) and "gather" in op_name.lower()
 
     def is_nccl_broadcast(op_name: str):
-        return is_nccl_op(op_name) and \
-            "broadcast" in op_name.lower()
+        return is_nccl_op(op_name) and "broadcast" in op_name.lower()
 
     # Reduce ops types
     def is_cross_device_reduce_1stage(op_name: str):
@@ -269,114 +270,122 @@ def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame:
     ops = list(filter(lambda x: x not in nccl_other_ops, ops))
 
     cross_device_reduce_1stage_ops = list(
-        filter(lambda x: is_cross_device_reduce_1stage(x), ops))
+        filter(lambda x: is_cross_device_reduce_1stage(x), ops)
+    )
     ops = list(filter(lambda x: x not in cross_device_reduce_1stage_ops, ops))
 
     cross_device_reduce_2stage_ops = list(
-        filter(lambda x: is_cross_device_reduce_2stage(x), ops))
+        filter(lambda x: is_cross_device_reduce_2stage(x), ops)
+    )
     ops = list(filter(lambda x: x not in cross_device_reduce_2stage_ops, ops))
 
-    custom_ar_all_reduce_ops = list(
-        filter(lambda x: is_custom_ar_all_reduce(x), ops))
+    custom_ar_all_reduce_ops = list(filter(lambda x: is_custom_ar_all_reduce(x), ops))
    ops = list(filter(lambda x: x not in custom_ar_all_reduce_ops, ops))
 
     reduce_kernel_ops = list(filter(lambda x: is_reduce_kernel(x), ops))
     ops = list(filter(lambda x: x not in reduce_kernel_ops, ops))
 
     if len(attention_ops):
-        trace_df['attention'] = trace_df[attention_ops].agg("sum", axis=1)
+        trace_df["attention"] = trace_df[attention_ops].agg("sum", axis=1)
     if len(quant_ops):
-        trace_df['quant_ops'] = trace_df[quant_ops].agg("sum", axis=1)
+        trace_df["quant_ops"] = trace_df[quant_ops].agg("sum", axis=1)
     if len(sgmv_shrink_ops):
-        trace_df['sgmv_shrink_ops'] = trace_df[sgmv_shrink_ops].agg("sum",
-                                                                    axis=1)
+        trace_df["sgmv_shrink_ops"] = trace_df[sgmv_shrink_ops].agg("sum", axis=1)
     if len(sgmv_expand_ops):
-        trace_df['sgmv_expand_ops'] = trace_df[sgmv_expand_ops].agg("sum",
-                                                                    axis=1)
+        trace_df["sgmv_expand_ops"] = trace_df[sgmv_expand_ops].agg("sum", axis=1)
     if len(bgmv_shrink_ops):
-        trace_df['bgmv_shrink_ops'] = trace_df[bgmv_shrink_ops].agg("sum",
-                                                                    axis=1)
+        trace_df["bgmv_shrink_ops"] = trace_df[bgmv_shrink_ops].agg("sum", axis=1)
     if len(bgmv_expand_ops):
-        trace_df['bgmv_expand_ops'] = trace_df[bgmv_expand_ops].agg("sum",
-                                                                    axis=1)
+        trace_df["bgmv_expand_ops"] = trace_df[bgmv_expand_ops].agg("sum", axis=1)
     if len(cutlass_gemm_ops):
-        trace_df['cutlass_gemm_ops'] = trace_df[cutlass_gemm_ops].agg("sum",
-                                                                      axis=1)
+        trace_df["cutlass_gemm_ops"] = trace_df[cutlass_gemm_ops].agg("sum", axis=1)
     if len(gemm_ops):
-        trace_df['gemm_ops'] = trace_df[gemm_ops].agg("sum", axis=1)
+        trace_df["gemm_ops"] = trace_df[gemm_ops].agg("sum", axis=1)
     if len(rms_norm_ops):
-        trace_df['rms_norm_ops'] = trace_df[rms_norm_ops].agg("sum", axis=1)
+        trace_df["rms_norm_ops"] = trace_df[rms_norm_ops].agg("sum", axis=1)
     if len(vocab_embed_ops):
-        trace_df['vocab_embed_ops'] = trace_df[vocab_embed_ops].agg("sum",
-                                                                    axis=1)
+        trace_df["vocab_embed_ops"] = trace_df[vocab_embed_ops].agg("sum", axis=1)
     if len(mem_ops):
-        trace_df['mem_ops'] = trace_df[mem_ops].agg("sum", axis=1)
+        trace_df["mem_ops"] = trace_df[mem_ops].agg("sum", axis=1)
     if len(elementwise_ops):
-        trace_df['elementwise_ops'] = trace_df[elementwise_ops].agg("sum",
-                                                                    axis=1)
+        trace_df["elementwise_ops"] = trace_df[elementwise_ops].agg("sum", axis=1)
 
     if len(nccl_all_reduce_ops):
-        trace_df['nccl_all_reduce_ops'] = trace_df[nccl_all_reduce_ops].agg(
-            "sum", axis=1)
+        trace_df["nccl_all_reduce_ops"] = trace_df[nccl_all_reduce_ops].agg(
+            "sum", axis=1
+        )
     if len(nccl_gather_ops):
-        trace_df['nccl_gather_ops'] = trace_df[nccl_gather_ops].agg("sum",
-                                                                    axis=1)
+        trace_df["nccl_gather_ops"] = trace_df[nccl_gather_ops].agg("sum", axis=1)
     if len(nccl_broadcast_ops):
-        trace_df['nccl_broadcast_ops'] = trace_df[nccl_broadcast_ops].agg(
-            "sum", axis=1)
+        trace_df["nccl_broadcast_ops"] = trace_df[nccl_broadcast_ops].agg("sum", axis=1)
     if len(nccl_other_ops):
-        trace_df['nccl_other_ops'] = trace_df[nccl_other_ops].agg("sum",
-                                                                  axis=1)
+        trace_df["nccl_other_ops"] = trace_df[nccl_other_ops].agg("sum", axis=1)
 
     if len(cross_device_reduce_1stage_ops):
-        trace_df['cross_device_reduce_1stage_ops'] = trace_df[
-            cross_device_reduce_1stage_ops].agg("sum", axis=1)
+        trace_df["cross_device_reduce_1stage_ops"] = trace_df[
+            cross_device_reduce_1stage_ops
+        ].agg("sum", axis=1)
     if len(cross_device_reduce_2stage_ops):
-        trace_df['cross_device_reduce_2stage_ops'] = trace_df[
-            cross_device_reduce_2stage_ops].agg("sum", axis=1)
+        trace_df["cross_device_reduce_2stage_ops"] = trace_df[
+            cross_device_reduce_2stage_ops
+        ].agg("sum", axis=1)
     if len(custom_ar_all_reduce_ops):
-        trace_df['custom_ar_all_reduce_ops'] = trace_df[
-            custom_ar_all_reduce_ops].agg("sum", axis=1)
+        trace_df["custom_ar_all_reduce_ops"] = trace_df[custom_ar_all_reduce_ops].agg(
+            "sum", axis=1
+        )
     if len(reduce_kernel_ops):
-        trace_df['reduce_kernel_ops'] = trace_df[reduce_kernel_ops].agg("sum",
-                                                                        axis=1)
+        trace_df["reduce_kernel_ops"] = trace_df[reduce_kernel_ops].agg("sum", axis=1)
 
-    trace_df.drop(attention_ops + quant_ops + sgmv_shrink_ops +
-                  sgmv_expand_ops + bgmv_shrink_ops + bgmv_expand_ops +
-                  cutlass_gemm_ops + gemm_ops + rms_norm_ops +
-                  vocab_embed_ops + mem_ops + elementwise_ops +
-                  nccl_all_reduce_ops + nccl_gather_ops + nccl_broadcast_ops +
-                  nccl_other_ops + cross_device_reduce_1stage_ops +
-                  cross_device_reduce_2stage_ops + custom_ar_all_reduce_ops +
-                  reduce_kernel_ops,
-                  axis=1,
-                  inplace=True)
+    trace_df.drop(
+        attention_ops
+        + quant_ops
+        + sgmv_shrink_ops
+        + sgmv_expand_ops
+        + bgmv_shrink_ops
+        + bgmv_expand_ops
+        + cutlass_gemm_ops
+        + gemm_ops
+        + rms_norm_ops
+        + vocab_embed_ops
+        + mem_ops
+        + elementwise_ops
+        + nccl_all_reduce_ops
+        + nccl_gather_ops
+        + nccl_broadcast_ops
+        + nccl_other_ops
+        + cross_device_reduce_1stage_ops
+        + cross_device_reduce_2stage_ops
+        + custom_ar_all_reduce_ops
+        + reduce_kernel_ops,
+        axis=1,
+        inplace=True,
+    )
     return trace_df
 
 
 ## Data plotting utils ####
 
 
-def plot_trace_df(traces_df: pd.DataFrame,
-                  plot_metric: str,
-                  plot_title: str,
-                  output: Optional[Path] = None):
+def plot_trace_df(
+    traces_df: pd.DataFrame,
+    plot_metric: str,
+    plot_title: str,
+    output: Optional[Path] = None,
+):
     def get_phase_description(traces_df: pd.DataFrame, phase: str) -> str:
         phase_df = traces_df.query(f'phase == "{phase}"')
-        descs = phase_df['phase_desc'].to_list()
+        descs = phase_df["phase_desc"].to_list()
         assert all([desc == descs[0] for desc in descs])
         return descs[0]
 
-    phases = traces_df['phase'].unique()
+    phases = traces_df["phase"].unique()
     phase_descs = [get_phase_description(traces_df, p) for p in phases]
-    traces_df = traces_df.pivot_table(index="phase",
-                                      columns="name",
-                                      values=plot_metric,
-                                      aggfunc="sum")
+    traces_df = traces_df.pivot_table(
+        index="phase", columns="name", values=plot_metric, aggfunc="sum"
+    )
 
     traces_df = group_trace_by_operations(traces_df)
@@ -396,20 +405,19 @@ def plot_trace_df(traces_df: pd.DataFrame,
     # Write the values as text on the bars
     for bar in ax.patches:
         if bar.get_height() != 0:
-            ax.text(bar.get_x() + bar.get_width() / 2,
-                    bar.get_height() / 2 + bar.get_y(),
-                    f"{round(bar.get_height(), 2)}",
-                    ha='center',
-                    color='w',
-                    weight='bold',
-                    size=5)
+            ax.text(
+                bar.get_x() + bar.get_width() / 2,
+                bar.get_height() / 2 + bar.get_y(),
+                f"{round(bar.get_height(), 2)}",
+                ha="center",
+                color="w",
+                weight="bold",
+                size=5,
+            )
 
     # Setup legend
     handles, labels = plt.gca().get_legend_handles_labels()
-    legend = fig.legend(handles,
-                        labels,
-                        loc='center left',
-                        bbox_to_anchor=(1, 1))
+    legend = fig.legend(handles, labels, loc="center left", bbox_to_anchor=(1, 1))
     shorten_plot_legend_strings(legend, 50)
 
     # Setup labels and title
@@ -417,21 +425,20 @@ def plot_trace_df(traces_df: pd.DataFrame,
     ax.set_ylabel(plot_metric)
     plt.suptitle(plot_title)
 
-    plt.savefig(output, bbox_inches='tight')
+    plt.savefig(output, bbox_inches="tight")
     print("Created: ", output)
 
 
 def main(
-        json_trace: Path,
-        output_directory: Path,
-        depth: int,  # Fetch/Plot operations at this depth of the Json tree
-        plot_metric: str,
-        make_names_unique: bool,
-        top_k: int,
-        json_nodes_to_fold: list[str]):
+    json_trace: Path,
+    output_directory: Path,
+    depth: int,  # Fetch/Plot operations at this depth of the Json tree
+    plot_metric: str,
+    make_names_unique: bool,
+    top_k: int,
+    json_nodes_to_fold: list[str],
+):
     def prepare_data(profile_json: dict, step_keys: list[str]) -> pd.DataFrame:
         def get_entries_and_traces(key: str):
             entries_and_traces: list[tuple[Any, Any]] = []
             for root in profile_json[key]["summary_stats"]:
@@ -441,16 +448,14 @@ def main(
                 get_entries_at_depth(depth, entries_and_traces, root)
             return entries_and_traces
 
-        def keep_only_top_entries(df: pd.DataFrame,
-                                  metric: str,
-                                  top_k: int = 9) -> pd.DataFrame:
-            df.loc[df.nsmallest(len(df) - top_k + 1, metric).index,
-                   ["name"]] = "others"
+        def keep_only_top_entries(
+            df: pd.DataFrame, metric: str, top_k: int = 9
+        ) -> pd.DataFrame:
+            df.loc[df.nsmallest(len(df) - top_k + 1, metric).index, ["name"]] = "others"
             return df
 
         def get_phase_description(key: str) -> str:
-            num_running_seqs = profile_json[key]['metadata'][
-                'num_running_seqs']
+            num_running_seqs = profile_json[key]["metadata"]["num_running_seqs"]
             if num_running_seqs is not None:
                 return f"{key}-seqs-{num_running_seqs}"
             else:
@@ -466,20 +471,24 @@ def main(
         # To pandas dataframe
         trace_dfs = list(
-            map(lambda t: pd.DataFrame([entry for entry, _ in t]).fillna(0),
-                traces))
+            map(lambda t: pd.DataFrame([entry for entry, _ in t]).fillna(0), traces)
+        )
 
         # Respect top_k
         if top_k:
             trace_dfs = list(
                 map(
                     lambda trace_df: keep_only_top_entries(
-                        trace_df, "cuda_time_us", top_k), trace_dfs))
+                        trace_df, "cuda_time_us", top_k
+                    ),
+                    trace_dfs,
+                )
+            )
 
         # Fill in information about the step-keys
         for trace_df, step_key in zip(trace_dfs, step_keys):
-            trace_df['phase'] = step_key
-            trace_df['phase_desc'] = get_phase_description(step_key)
+            trace_df["phase"] = step_key
+            trace_df["phase_desc"] = get_phase_description(step_key)
 
         # Combine all data frames so they can be put in a single plot
         traces_df = pd.concat(trace_dfs)
@@ -492,17 +501,23 @@ def main(
     def make_plot_title_suffix(profile_json: dict) -> str:
         context = profile_json["context"]
-        sparsity = context.get('sparsity', None)
-        run_type = \
-            f'Run {context["num_steps"]} steps' if context['num_steps'] else \
-            (f'Complete {context["complete_num_requests_per_step"]} per '
-             f'step; Run till completion')
-        return (f"{context['engine_args']['model']}\n"
-                f"Batch={context['batch_size']}, "
-                f"PromptLen={context['prompt_len']}, "
-                f"NumGpus={context['engine_args']['tensor_parallel_size']}"
-                f"{', Sparsity ' + sparsity if sparsity else ''}\n"
-                f"Run Type: {run_type}")
+        sparsity = context.get("sparsity", None)
+        run_type = (
+            f"Run {context['num_steps']} steps"
+            if context["num_steps"]
+            else (
+                f"Complete {context['complete_num_requests_per_step']} per "
+                f"step; Run till completion"
+            )
+        )
+        return (
+            f"{context['engine_args']['model']}\n"
+            f"Batch={context['batch_size']}, "
+            f"PromptLen={context['prompt_len']}, "
+            f"NumGpus={context['engine_args']['tensor_parallel_size']}"
+            f"{', Sparsity ' + sparsity if sparsity else ''}\n"
+            f"Run Type: {run_type}"
+        )
 
     profile_json = None
     with open(json_trace) as f:
@@ -511,14 +526,14 @@ def main(
     # Get all `llm.generate.step()` profile
     step_traces = list(profile_json.keys())
-    assert (step_traces[0] == 'context')
+    assert step_traces[0] == "context"
     step_traces = step_traces[1:]  # have only prefill and decodes
     prefills = list(filter(lambda x: "prefill" in x, step_traces))
     all_decodes = list(filter(lambda x: "decode" in x, step_traces))
     assert len(prefills) + len(all_decodes) == len(step_traces)
     assert len(prefills) == 1
 
-    decodes = all_decodes[::args.step_plot_interval]
+    decodes = all_decodes[:: args.step_plot_interval]
     if decodes[-1] != all_decodes[-1]:
         # Always have the last decode
         decodes.append(all_decodes[-1])
@@ -528,48 +543,63 @@ def main(
     plot_title_suffix = make_plot_title_suffix(profile_json)
 
-    plot_trace_df(prefill_traces, plot_metric, "prefill " + plot_title_suffix,
-                  output_directory / Path("prefill.png"))
-    plot_trace_df(decode_traces, plot_metric, "decodes " + plot_title_suffix,
-                  output_directory / Path("decode_steps.png"))
+    plot_trace_df(
+        prefill_traces,
+        plot_metric,
+        "prefill " + plot_title_suffix,
+        output_directory / Path("prefill.png"),
+    )
+    plot_trace_df(
+        decode_traces,
+        plot_metric,
+        "decodes " + plot_title_suffix,
+        output_directory / Path("decode_steps.png"),
+    )
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
 
-    parser.add_argument("--json-trace",
-                        type=str,
-                        required=True,
-                        help="json trace file output by \
-                              examples/offline_inference/profiling.py")
-    parser.add_argument("--output-directory",
-                        type=str,
-                        required=False,
-                        help="Directory to output plots")
-    parser.add_argument("--level",
-                        type=str,
-                        default="module",
-                        choices=["module", "kernel"])
-    parser.add_argument("--top-k",
-                        type=int,
-                        default=12,
-                        help="Only graph the top `top_k` entries by time.")
-    parser.add_argument("--fold-json-node",
-                        nargs='+',
-                        default=['Sampler', 'LogitsProcessor'],
-                        help='Do not plot the children of these nodes. Let, \
-                              the node represent the aggregate of all its \
-                              children')
-    parser.add_argument("--plot-metric",
-                        type=str,
-                        default="cuda_time_ms",
-                        help='Metric to plot. some options are cuda_time_ms, \
-                              pct_cuda_time')
+    parser.add_argument(
+        "--json-trace",
+        type=str,
+        required=True,
+        help="json trace file output by \
+              examples/offline_inference/profiling.py",
+    )
+    parser.add_argument(
+        "--output-directory", type=str, required=False, help="Directory to output plots"
+    )
+    parser.add_argument(
+        "--level", type=str, default="module", choices=["module", "kernel"]
+    )
+    parser.add_argument(
+        "--top-k",
+        type=int,
+        default=12,
+        help="Only graph the top `top_k` entries by time.",
+    )
+    parser.add_argument(
+        "--fold-json-node",
+        nargs="+",
+        default=["Sampler", "LogitsProcessor"],
+        help="Do not plot the children of these nodes. Let, \
+              the node represent the aggregate of all its \
+              children",
+    )
+    parser.add_argument(
+        "--plot-metric",
+        type=str,
+        default="cuda_time_ms",
+        help="Metric to plot. some options are cuda_time_ms, \
+              pct_cuda_time",
+    )
     parser.add_argument(
         "--step-plot-interval",
         type=int,
         default=4,
-        help="For every `step_plot_interval` steps, plot 1 step")
+        help="For every `step_plot_interval` steps, plot 1 step",
+    )
 
     args = parser.parse_args()
@@ -583,11 +613,19 @@ if __name__ == "__main__":
     else:
         raise Exception(f"Unexpected level value ({args.level})")
 
-    output_directory = args.output_directory if args.output_directory else Path(
-        args.json_trace).parent
+    output_directory = (
+        args.output_directory if args.output_directory else Path(args.json_trace).parent
+    )
 
     if not os.path.exists(output_directory):
        os.makedirs(output_directory)
 
-    main(Path(args.json_trace), output_directory, depth, args.plot_metric,
-         make_names_unique, args.top_k, args.fold_json_node)
+    main(
+        Path(args.json_trace),
+        output_directory,
+        depth,
+        args.plot_metric,
+        make_names_unique,
+        args.top_k,
+        args.fold_json_node,
+    )
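
The script's command-line interface is untouched by the reformat. For reference, a typical invocation built from the argparse flags above (the script and trace file names here are illustrative, not taken from the commit) would be:

    python visualize_layerwise_profile.py --json-trace profile.json \
        --level module --top-k 12 --plot-metric cuda_time_ms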