[misc] Layerwise profile updates (#10242)

Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
This commit is contained in:
Varun Sundar Rabindranath
2024-12-16 13:14:57 -05:00
committed by GitHub
parent 2ca830dbaa
commit efbce85f4d
5 changed files with 314 additions and 47 deletions

View File

@@ -34,9 +34,10 @@ if __name__ == "__main__":
"examples/offline_profile.py")
parser.add_argument("--phase",
type=str,
choices=["prefill", "decode_1"],
required=True,
help="The phase to print the table for.")
help="The phase to print the table for. This is either"
"prefill or decode_n, where n is the decode step "
"number")
parser.add_argument("--table",
type=str,
choices=["summary", "model"],
@@ -49,6 +50,10 @@ if __name__ == "__main__":
with open(args.json_trace) as f:
profile_data = json.load(f)
assert args.phase in profile_data, \
(f"Cannot find phase {args.phase} in profile data. Choose one among"
f'{[x for x in profile_data.keys() if "prefill" in x or "decode" in x]}') #noqa
if args.table == "summary":
entries_and_depths = flatten_entries(
SummaryStatsEntry, profile_data[args.phase]["summary_stats"])

View File

@@ -151,16 +151,31 @@ def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame:
"scaled_int8_quant" in op_name:
return True
# LoRA ops
def is_sgmv_shrink(op_name: str):
    """Return True when ``op_name`` contains the substring "sgmv_shrink"."""
    marker = "sgmv_shrink"
    return marker in op_name
def is_sgmv_expand(op_name: str):
    """Return True when ``op_name`` contains the substring "sgmv_expand"."""
    marker = "sgmv_expand"
    return marker in op_name
def is_bgmv_shrink(op_name: str):
    """Return True when ``op_name`` contains the substring "bgmv_shrink"."""
    marker = "bgmv_shrink"
    return marker in op_name
def is_bgmv_expand(op_name: str):
    """Return True when ``op_name`` contains the substring "bgmv_expand"."""
    marker = "bgmv_expand"
    return marker in op_name
def is_cutlass_gemm_op(op_name: str):
    """Return True when ``op_name`` contains one of the two CUTLASS kernel
    signatures checked below."""
    cutlass_markers = (
        "void cutlass::Kernel",
        "void cutlass::device_kernel",
    )
    return any(marker in op_name for marker in cutlass_markers)
def is_gemm_op(op_name: str):
if is_quant(op_name):
return False
if "xmma_gemm" in op_name or \
return is_cutlass_gemm_op(op_name) or \
"xmma_gemm" in op_name or \
"gemv2T_kernel" in op_name or \
"splitKreduce" in op_name or \
"void cutlass::Kernel" in op_name or \
"void cutlass::device_kernel" in op_name or \
"s16816gemm" in op_name:
return True
"s16816gemm" in op_name
def is_elementwise_op(op_name: str):
return "elementwise_kernel" in op_name
@@ -211,6 +226,18 @@ def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame:
quant_ops = list(filter(lambda x: is_quant(x), ops))
ops = list(filter(lambda x: x not in quant_ops, ops))
sgmv_shrink_ops = list(filter(lambda x: is_sgmv_shrink(x), ops))
ops = list(filter(lambda x: x not in sgmv_shrink_ops, ops))
sgmv_expand_ops = list(filter(lambda x: is_sgmv_expand(x), ops))
ops = list(filter(lambda x: x not in sgmv_expand_ops, ops))
bgmv_shrink_ops = list(filter(lambda x: is_bgmv_shrink(x), ops))
ops = list(filter(lambda x: x not in bgmv_shrink_ops, ops))
bgmv_expand_ops = list(filter(lambda x: is_bgmv_expand(x), ops))
ops = list(filter(lambda x: x not in bgmv_expand_ops, ops))
cutlass_gemm_ops = list(filter(lambda x: is_cutlass_gemm_op(x), ops))
ops = list(filter(lambda x: x not in cutlass_gemm_ops, ops))
gemm_ops = list(filter(lambda x: is_gemm_op(x), ops))
ops = list(filter(lambda x: x not in gemm_ops, ops))
@@ -257,6 +284,24 @@ def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame:
trace_df['attention'] = trace_df[attention_ops].agg("sum", axis=1)
if len(quant_ops):
trace_df['quant_ops'] = trace_df[quant_ops].agg("sum", axis=1)
if len(sgmv_shrink_ops):
trace_df['sgmv_shrink_ops'] = trace_df[sgmv_shrink_ops].agg("sum",
axis=1)
if len(sgmv_expand_ops):
trace_df['sgmv_expand_ops'] = trace_df[sgmv_expand_ops].agg("sum",
axis=1)
if len(bgmv_shrink_ops):
trace_df['bgmv_shrink_ops'] = trace_df[bgmv_shrink_ops].agg("sum",
axis=1)
if len(bgmv_expand_ops):
trace_df['bgmv_expand_ops'] = trace_df[bgmv_expand_ops].agg("sum",
axis=1)
if len(cutlass_gemm_ops):
trace_df['cutlass_gemm_ops'] = trace_df[cutlass_gemm_ops].agg("sum",
axis=1)
if len(gemm_ops):
trace_df['gemm_ops'] = trace_df[gemm_ops].agg("sum", axis=1)
if len(rms_norm_ops):
@@ -296,7 +341,9 @@ def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame:
trace_df['reduce_kernel_ops'] = trace_df[reduce_kernel_ops].agg("sum",
axis=1)
trace_df.drop(attention_ops + quant_ops + gemm_ops + rms_norm_ops +
trace_df.drop(attention_ops + quant_ops + sgmv_shrink_ops +
sgmv_expand_ops + bgmv_shrink_ops + bgmv_expand_ops +
cutlass_gemm_ops + gemm_ops + rms_norm_ops +
vocab_embed_ops + mem_ops + elementwise_ops +
nccl_all_reduce_ops + nccl_gather_ops + nccl_broadcast_ops +
nccl_other_ops + cross_device_reduce_1stage_ops +
@@ -315,7 +362,14 @@ def plot_trace_df(traces_df: pd.DataFrame,
plot_title: str,
output: Optional[Path] = None):
def get_phase_description(traces_df: pd.DataFrame, phase: str) -> str:
    """Return the single ``phase_desc`` value shared by all rows of ``phase``.

    Args:
        traces_df: frame with at least a ``phase`` and a ``phase_desc``
            column.
        phase: value to match against the ``phase`` column.

    Returns:
        The ``phase_desc`` of the matching rows.

    Raises:
        IndexError: no row has ``phase`` equal to the requested value.
        AssertionError: the matching rows disagree on ``phase_desc``.
    """
    # Select with a boolean mask rather than an interpolated query()
    # expression, so a phase name containing quotes or other expression
    # syntax cannot break (or change the meaning of) the lookup.
    descs = traces_df.loc[traces_df['phase'] == phase, 'phase_desc'].to_list()
    if not descs:
        raise IndexError(f"no rows found for phase {phase!r}")
    first = descs[0]
    assert all(desc == first for desc in descs)
    return first
phases = traces_df['phase'].unique()
phase_descs = [get_phase_description(traces_df, p) for p in phases]
traces_df = traces_df.pivot_table(index="phase",
columns="name",
values=plot_metric,
@@ -324,7 +378,8 @@ def plot_trace_df(traces_df: pd.DataFrame,
traces_df = group_trace_by_operations(traces_df)
# Make the figure
fig, ax = plt.subplots(1, figsize=(5, 8), sharex=True)
fig_size_x = max(5, len(phases))
fig, ax = plt.subplots(1, figsize=(fig_size_x, 8), sharex=True)
# Draw the stacked bars
ops = list(traces_df)
@@ -332,7 +387,7 @@ def plot_trace_df(traces_df: pd.DataFrame,
for op in ops:
values = [traces_df[op][phase] for phase in phases]
values = list(map(lambda x: 0.0 if math.isnan(x) else x, values))
ax.bar(phases, values, label=op, bottom=bottom)
ax.bar(phase_descs, values, label=op, bottom=bottom)
bottom = [bottom[j] + values[j] for j in range(len(phases))]
# Write the values as text on the bars
@@ -390,6 +445,14 @@ def main(
["name"]] = "others"
return df
def get_phase_description(key: str) -> str:
    """Build a display label for the profile entry stored under ``key``.

    Looks up ``profile_json[key]['metadata']['num_running_seqs']``
    (``profile_json`` is defined outside this function). When that value is
    None the key itself is returned; otherwise the value is appended as a
    ``-seqs-<value>`` suffix.
    """
    entry_metadata = profile_json[key]['metadata']
    seq_count = entry_metadata['num_running_seqs']
    if seq_count is None:
        return key
    return f"{key}-seqs-{seq_count}"
# Get data for each key
traces = list(map(lambda x: get_entries_and_traces(x), step_keys))
@@ -413,6 +476,7 @@ def main(
# Fill in information about the step-keys
for trace_df, step_key in zip(trace_dfs, step_keys):
trace_df['phase'] = step_key
trace_df['phase_desc'] = get_phase_description(step_key)
# Combine all data frames so they can be put in a single plot
traces_df = pd.concat(trace_dfs)
@@ -426,12 +490,16 @@ def main(
def make_plot_title_suffix(profile_json: dict) -> str:
context = profile_json["context"]
sparsity = context.get('sparsity', None)
return (f"{context['model']}\n"
run_type = \
f'Run {context["num_steps"]} steps' if context['num_steps'] else \
(f'Complete {context["complete_num_requests_per_step"]} per '
f'step; Run till completion')
return (f"{context['engine_args']['model']}\n"
f"Batch={context['batch_size']}, "
f"PromptLen={context['prompt_len']}, "
f"OutputLen={context['output_len']},"
f"NumGpus={context['tensor_parallel_size']}"
f"{', Sparsity ' + sparsity if sparsity else ''}")
f"NumGpus={context['engine_args']['tensor_parallel_size']}"
f"{', Sparsity ' + sparsity if sparsity else ''}\n"
f"Run Type: {run_type}")
profile_json = None
with open(json_trace) as f: