fairlearn.metrics.plot_model_comparison

fairlearn.metrics.plot_model_comparison(*, y_preds, y_true=None, sensitive_features=None, x_axis_metric=None, y_axis_metric=None, ax=None, axis_labels=True, point_labels=False, point_labels_position=(0, 0), legend=False, show_plot=False, **kwargs)

Create a scatter plot comparing multiple models along two metrics.

A typical use case is when one of the metrics is a performance metric (e.g., balanced_accuracy) and the other is a fairness metric (e.g., false_negative_rate_difference).

Parameters
  • y_preds (array-like, dict of array-like) – The predictions of each model. The predictions of model i must be in y_preds[i]; if y_preds is a dictionary, i is the key that identifies the model.

  • y_true (List, pandas.Series, numpy.ndarray, pandas.DataFrame) – The ground-truth labels (for classification) or target values (for regression).

  • sensitive_features (List, pandas.Series, dict of 1d arrays, numpy.ndarray, pandas.DataFrame, optional) – Sensitive features for the fairness metrics (if a fairness metric is specified for the x-axis or the y-axis).

  • x_axis_metric (Callable) – The metric function for the x-axis. The metric function must take y_true, y_pred, and optionally sensitive_features as arguments, and return a scalar value.

  • y_axis_metric (Callable) – The metric function for the y-axis, with the same requirements as x_axis_metric.

  • ax (matplotlib.axes.Axes, optional) – If supplied, the scatter plot is drawn on this Axes object. Otherwise, a new figure and Axes are created.

  • axis_labels (bool, list) – If True, label the x and y axes with the names of the metrics. You can also pass a list (or a two-tuple) of two strings to use as the axis labels instead.

  • point_labels (bool, list) – If True, annotate each point with an inferred label: the keys of y_preds if y_preds is a dictionary, otherwise the integers 0 to n - 1, where n is the number of models. You can also pass point_labels as a list of labels.

  • point_labels_position (list) – A list (or a two-tuple) of exactly two numbers that define the offset of the point labels in the x and y directions, respectively. The offset is in data coordinates, not pixels.

  • legend (bool) – If True, add a legend. Legend entries are created by passing the keyword argument label when calling this function. To customize the legend, call ax.legend yourself (where ax is the Axes object) with your own parameters.

  • show_plot (bool) – If True, call pyplot.show.
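The metric contract described for x_axis_metric and y_axis_metric can be illustrated with a custom fairness metric. The function below, accuracy_group_gap, is hypothetical (not part of fairlearn): a sketch assuming that sensitive_features is passed by keyword, as it is for fairlearn's own fairness metrics. It takes y_true, y_pred and sensitive_features, and returns a single scalar.

```python
import numpy as np


def accuracy_group_gap(y_true, y_pred, *, sensitive_features):
    """Hypothetical fairness metric: largest minus smallest per-group accuracy.

    Takes y_true, y_pred and a sensitive_features keyword argument and
    returns one scalar, which is what plot_model_comparison expects.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    groups = np.asarray(sensitive_features)
    # Accuracy computed separately within each group.
    accuracies = [
        np.mean(y_true[groups == g] == y_pred[groups == g])
        for g in np.unique(groups)
    ]
    return float(max(accuracies) - min(accuracies))


print(accuracy_group_gap([0, 1, 1, 0], [0, 1, 0, 0], sensitive_features=["a", "a", "b", "b"]))
```

Group "a" is predicted perfectly (accuracy 1.0) and group "b" half correctly (accuracy 0.5), so the call prints 0.5. Such a function could then be passed as y_axis_metric together with sensitive_features.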

Returns

ax – The Axes object that was drawn on.

Return type

matplotlib.axes.Axes

Notes

To customize the style of the plot beyond the API options above, there are at least three approaches:

  1) Pass matplotlib keyword arguments to plot_model_comparison, as you would to matplotlib.axes.Axes.scatter.

  2) Change the style of the returned Axes.

  3) Supply an Axes that already has your own style applied.

If no Axes object is supplied, the axis labels are automatically inferred from the names of the metric functions.