[misc] Layerwise profile updates (#10242)
Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
This commit is contained in:
Committed by: GitHub
Parent: 2ca830dbaa
Commit: efbce85f4d
@@ -34,9 +34,10 @@ if __name__ == "__main__":
|
||||
"examples/offline_profile.py")
|
||||
# --phase selects which profiled phase to print. No fixed `choices` list is
# given because decode phases are numbered (decode_1, decode_2, ...); the
# value is checked against the loaded JSON trace further down instead.
parser.add_argument("--phase",
                    type=str,
                    required=True,
                    help="The phase to print the table for. This is either "
                    "prefill or decode_n, where n is the decode step "
                    "number")
|
||||
parser.add_argument("--table",
|
||||
type=str,
|
||||
choices=["summary", "model"],
|
||||
@@ -49,6 +50,10 @@ if __name__ == "__main__":
|
||||
with open(args.json_trace) as f:
|
||||
profile_data = json.load(f)
|
||||
|
||||
# The requested phase must be a top-level key of the loaded trace. On failure,
# list the prefill/decode keys that are actually present so the user can pick
# a valid one.
available_phases = [
    x for x in profile_data if "prefill" in x or "decode" in x
]
assert args.phase in profile_data, \
    (f"Cannot find phase {args.phase} in profile data. Choose one among "
     f"{available_phases}")
|
||||
|
||||
if args.table == "summary":
|
||||
entries_and_depths = flatten_entries(
|
||||
SummaryStatsEntry, profile_data[args.phase]["summary_stats"])
|
||||
|
||||
@@ -151,16 +151,31 @@ def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame:
|
||||
"scaled_int8_quant" in op_name:
|
||||
return True
|
||||
|
||||
# LoRA ops
|
||||
def is_sgmv_shrink(op_name: str) -> bool:
    """Return True if `op_name` contains the substring "sgmv_shrink".

    Grouped with the LoRA ops by the surrounding code.
    """
    return "sgmv_shrink" in op_name
|
||||
|
||||
def is_sgmv_expand(op_name: str) -> bool:
    """Return True if `op_name` contains the substring "sgmv_expand".

    Grouped with the LoRA ops by the surrounding code.
    """
    return "sgmv_expand" in op_name
|
||||
|
||||
def is_bgmv_shrink(op_name: str) -> bool:
    """Return True if `op_name` contains the substring "bgmv_shrink".

    Grouped with the LoRA ops by the surrounding code.
    """
    return "bgmv_shrink" in op_name
|
||||
|
||||
def is_bgmv_expand(op_name: str) -> bool:
    """Return True if `op_name` contains the substring "bgmv_expand".

    Grouped with the LoRA ops by the surrounding code.
    """
    return "bgmv_expand" in op_name
|
||||
|
||||
def is_cutlass_gemm_op(op_name: str) -> bool:
    """Return True if `op_name` contains one of the two CUTLASS kernel prefixes.

    The match is a plain substring test against "void cutlass::Kernel" and
    "void cutlass::device_kernel".
    """
    return "void cutlass::Kernel" in op_name or \
        "void cutlass::device_kernel" in op_name
|
||||
|
||||
def is_gemm_op(op_name: str) -> bool:
    """Return True if `op_name` is classified as a GEMM kernel.

    Names accepted by `is_quant` are never GEMM ops. Otherwise a name counts
    as GEMM when `is_cutlass_gemm_op` accepts it or when it contains one of
    the substrings "xmma_gemm", "gemv2T_kernel", "splitKreduce" or
    "s16816gemm".
    """
    # Quant ops are excluded first so they are not counted as GEMM ops.
    if is_quant(op_name):
        return False
    return is_cutlass_gemm_op(op_name) or \
        "xmma_gemm" in op_name or \
        "gemv2T_kernel" in op_name or \
        "splitKreduce" in op_name or \
        "s16816gemm" in op_name
|
||||
|
||||
def is_elementwise_op(op_name: str) -> bool:
    """Return True if `op_name` contains the substring "elementwise_kernel"."""
    return "elementwise_kernel" in op_name
|
||||
@@ -211,6 +226,18 @@ def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame:
|
||||
quant_ops = list(filter(lambda x: is_quant(x), ops))
|
||||
ops = list(filter(lambda x: x not in quant_ops, ops))
|
||||
|
||||
sgmv_shrink_ops = list(filter(lambda x: is_sgmv_shrink(x), ops))
|
||||
ops = list(filter(lambda x: x not in sgmv_shrink_ops, ops))
|
||||
sgmv_expand_ops = list(filter(lambda x: is_sgmv_expand(x), ops))
|
||||
ops = list(filter(lambda x: x not in sgmv_expand_ops, ops))
|
||||
bgmv_shrink_ops = list(filter(lambda x: is_bgmv_shrink(x), ops))
|
||||
ops = list(filter(lambda x: x not in bgmv_shrink_ops, ops))
|
||||
bgmv_expand_ops = list(filter(lambda x: is_bgmv_expand(x), ops))
|
||||
ops = list(filter(lambda x: x not in bgmv_expand_ops, ops))
|
||||
|
||||
cutlass_gemm_ops = list(filter(lambda x: is_cutlass_gemm_op(x), ops))
|
||||
ops = list(filter(lambda x: x not in cutlass_gemm_ops, ops))
|
||||
|
||||
gemm_ops = list(filter(lambda x: is_gemm_op(x), ops))
|
||||
ops = list(filter(lambda x: x not in gemm_ops, ops))
|
||||
|
||||
@@ -257,6 +284,24 @@ def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame:
|
||||
trace_df['attention'] = trace_df[attention_ops].agg("sum", axis=1)
|
||||
if len(quant_ops):
|
||||
trace_df['quant_ops'] = trace_df[quant_ops].agg("sum", axis=1)
|
||||
|
||||
if len(sgmv_shrink_ops):
|
||||
trace_df['sgmv_shrink_ops'] = trace_df[sgmv_shrink_ops].agg("sum",
|
||||
axis=1)
|
||||
if len(sgmv_expand_ops):
|
||||
trace_df['sgmv_expand_ops'] = trace_df[sgmv_expand_ops].agg("sum",
|
||||
axis=1)
|
||||
if len(bgmv_shrink_ops):
|
||||
trace_df['bgmv_shrink_ops'] = trace_df[bgmv_shrink_ops].agg("sum",
|
||||
axis=1)
|
||||
if len(bgmv_expand_ops):
|
||||
trace_df['bgmv_expand_ops'] = trace_df[bgmv_expand_ops].agg("sum",
|
||||
axis=1)
|
||||
|
||||
if len(cutlass_gemm_ops):
|
||||
trace_df['cutlass_gemm_ops'] = trace_df[cutlass_gemm_ops].agg("sum",
|
||||
axis=1)
|
||||
|
||||
if len(gemm_ops):
|
||||
trace_df['gemm_ops'] = trace_df[gemm_ops].agg("sum", axis=1)
|
||||
if len(rms_norm_ops):
|
||||
@@ -296,7 +341,9 @@ def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame:
|
||||
trace_df['reduce_kernel_ops'] = trace_df[reduce_kernel_ops].agg("sum",
|
||||
axis=1)
|
||||
|
||||
trace_df.drop(attention_ops + quant_ops + gemm_ops + rms_norm_ops +
|
||||
trace_df.drop(attention_ops + quant_ops + sgmv_shrink_ops +
|
||||
sgmv_expand_ops + bgmv_shrink_ops + bgmv_expand_ops +
|
||||
cutlass_gemm_ops + gemm_ops + rms_norm_ops +
|
||||
vocab_embed_ops + mem_ops + elementwise_ops +
|
||||
nccl_all_reduce_ops + nccl_gather_ops + nccl_broadcast_ops +
|
||||
nccl_other_ops + cross_device_reduce_1stage_ops +
|
||||
@@ -315,7 +362,14 @@ def plot_trace_df(traces_df: pd.DataFrame,
|
||||
plot_title: str,
|
||||
output: Optional[Path] = None):
|
||||
|
||||
def get_phase_description(traces_df: pd.DataFrame, phase: str) -> str:
    """Return the `phase_desc` value shared by every row of `phase`.

    Args:
        traces_df: Frame with at least a 'phase' and a 'phase_desc' column.
        phase: Value of the 'phase' column to look up.

    Returns:
        The 'phase_desc' of the first matching row.

    Raises:
        AssertionError: If the matching rows disagree on 'phase_desc'.
        IndexError: If no row matches `phase`.
    """
    # Select with a boolean mask rather than a string-built query so the
    # lookup does not depend on how `phase` is quoted.
    descs = traces_df.loc[traces_df['phase'] == phase, 'phase_desc'].to_list()
    # All rows of one phase are expected to carry the same description.
    assert all(desc == descs[0] for desc in descs)
    return descs[0]
|
||||
|
||||
phases = traces_df['phase'].unique()
|
||||
phase_descs = [get_phase_description(traces_df, p) for p in phases]
|
||||
traces_df = traces_df.pivot_table(index="phase",
|
||||
columns="name",
|
||||
values=plot_metric,
|
||||
@@ -324,7 +378,8 @@ def plot_trace_df(traces_df: pd.DataFrame,
|
||||
traces_df = group_trace_by_operations(traces_df)
|
||||
|
||||
# Make the figure
|
||||
fig, ax = plt.subplots(1, figsize=(5, 8), sharex=True)
|
||||
fig_size_x = max(5, len(phases))
|
||||
fig, ax = plt.subplots(1, figsize=(fig_size_x, 8), sharex=True)
|
||||
|
||||
# Draw the stacked bars
|
||||
ops = list(traces_df)
|
||||
@@ -332,7 +387,7 @@ def plot_trace_df(traces_df: pd.DataFrame,
|
||||
for op in ops:
|
||||
values = [traces_df[op][phase] for phase in phases]
|
||||
values = list(map(lambda x: 0.0 if math.isnan(x) else x, values))
|
||||
ax.bar(phases, values, label=op, bottom=bottom)
|
||||
ax.bar(phase_descs, values, label=op, bottom=bottom)
|
||||
bottom = [bottom[j] + values[j] for j in range(len(phases))]
|
||||
|
||||
# Write the values as text on the bars
|
||||
@@ -390,6 +445,14 @@ def main(
|
||||
["name"]] = "others"
|
||||
return df
|
||||
|
||||
def get_phase_description(key: str) -> str:
    """Build the label used for phase `key`.

    Reads `profile_json` from the enclosing scope and looks up
    `profile_json[key]['metadata']['num_running_seqs']`.

    Returns:
        "<key>-seqs-<num_running_seqs>" when the value is not None,
        otherwise `key` unchanged.
    """
    num_running_seqs = profile_json[key]['metadata']['num_running_seqs']
    # Compare against None explicitly so that a count of 0 is still shown.
    if num_running_seqs is None:
        return key
    return f"{key}-seqs-{num_running_seqs}"
|
||||
|
||||
# Get data for each key
|
||||
traces = list(map(lambda x: get_entries_and_traces(x), step_keys))
|
||||
|
||||
@@ -413,6 +476,7 @@ def main(
|
||||
# Fill in information about the step-keys
|
||||
for trace_df, step_key in zip(trace_dfs, step_keys):
|
||||
trace_df['phase'] = step_key
|
||||
trace_df['phase_desc'] = get_phase_description(step_key)
|
||||
|
||||
# Combine all data frames so they can be put in a single plot
|
||||
traces_df = pd.concat(trace_dfs)
|
||||
@@ -426,12 +490,16 @@ def main(
|
||||
def make_plot_title_suffix(profile_json: dict) -> str:
    """Build the multi-line plot-title suffix from the profile's context.

    Args:
        profile_json: Loaded profile; its "context" entry must provide
            'engine_args' (with 'model' and 'tensor_parallel_size'),
            'batch_size', 'prompt_len' and 'num_steps'. The key
            'complete_num_requests_per_step' is read only when 'num_steps'
            is falsy. 'sparsity' is optional.

    Returns:
        Three lines: the model name, the run parameters (with an optional
        sparsity suffix), and the run type.
    """
    context = profile_json["context"]
    engine_args = context['engine_args']
    sparsity = context.get('sparsity', None)

    # A truthy num_steps means a fixed number of steps was run; otherwise
    # the run went to completion.
    if context['num_steps']:
        run_type = f'Run {context["num_steps"]} steps'
    else:
        run_type = (f'Complete {context["complete_num_requests_per_step"]} '
                    f'per step; Run till completion')

    return (f"{engine_args['model']}\n"
            f"Batch={context['batch_size']}, "
            f"PromptLen={context['prompt_len']}, "
            f"NumGpus={engine_args['tensor_parallel_size']}"
            f"{', Sparsity ' + sparsity if sparsity else ''}\n"
            f"Run Type: {run_type}")
|
||||
|
||||
profile_json = None
|
||||
with open(json_trace) as f:
|
||||
|
||||
Reference in New Issue
Block a user