diff --git a/vllm/benchmarks/sweep/plot.py b/vllm/benchmarks/sweep/plot.py index 376adbb08..87323757e 100644 --- a/vllm/benchmarks/sweep/plot.py +++ b/vllm/benchmarks/sweep/plot.py @@ -346,7 +346,45 @@ def _plot_fig( else "(All)" ) - g = sns.FacetGrid(df, row="row_group", col="col_group", height=fig_height) + if len(curve_by) <= 3: + hue, style, size, *_ = (*curve_by, None, None, None) + + g = sns.relplot( + df, + x=var_x, + y=var_y, + hue=hue, + style=style, + size=size, + markers=True, + errorbar="sd" if error_bars else None, + kind="line", + row="row_group", + col="col_group", + height=fig_height, + ) + else: + df["curve_group"] = ( + pd.concat( + [k + "=" + df[k].astype(str) for k in curve_by], + axis=1, + ).agg("\n".join, axis=1) + if curve_by + else "(All)" + ) + + g = sns.relplot( + df, + x=var_x, + y=var_y, + hue="curve_group", + markers=True, + errorbar="sd" if error_bars else None, + kind="line", + row="row_group", + col="col_group", + height=fig_height, + ) if row_by and col_by: g.set_titles("{row_name}\n{col_name}") @@ -362,42 +400,6 @@ def _plot_fig( if scale_y: g.set(yscale=scale_y) - if len(curve_by) <= 3: - hue, style, size, *_ = (*curve_by, None, None, None) - - g.map_dataframe( - sns.lineplot, - x=var_x, - y=var_y, - hue=hue, - style=style, - size=size, - markers=True, - errorbar="sd" if error_bars else None, - ) - - g.add_legend(title=hue) - else: - df["curve_group"] = ( - pd.concat( - [k + "=" + df[k].astype(str) for k in curve_by], - axis=1, - ).agg("\n".join, axis=1) - if curve_by - else "(All)" - ) - - g.map_dataframe( - sns.lineplot, - x=var_x, - y=var_y, - hue="curve_group", - markers=True, - errorbar="sd" if error_bars else None, - ) - - g.add_legend() - g.savefig(fig_path, dpi=fig_dpi) plt.close(g.figure)