Module music_df.plot_piano_rolls.plot_helper

Functions

def plot_predictions(music_df: pandas.DataFrame,
feature_name: str,
predicted_feature: Sequence[Any],
prediction_indices: list[int],
plot_predicted_notes_only: bool = True,
transpose: int | None = None,
colormapping=None,
label_notes: bool = True,
ax: matplotlib.axes._axes.Axes | None = None,
title: str | None = None)
Expand source code
def plot_predictions(
    music_df: pd.DataFrame,
    feature_name: str,
    predicted_feature: Sequence[Any],
    prediction_indices: list[int],
    plot_predicted_notes_only: bool = True,
    transpose: int | None = None,
    colormapping=None,
    label_notes: bool = True,
    ax: plt.Axes | None = None,
    title: str | None = None,
):
    """Plot a piano roll comparing predicted and actual values of a feature.

    Each note is colored according to its actual ``feature_name`` value (via
    ``colormapping``). When ``label_notes`` is true, each note is also labeled
    with its predicted value; the label is black where the prediction equals
    the actual value and red otherwise.

    The caller's dataframe is never modified: all work happens on a filtered
    or copied frame.

    Args:
        music_df: dataframe. Must have a ``type`` column (only rows where it
            equals ``"note"`` are plotted), a column named ``feature_name``,
            and, if ``transpose`` is given, a ``pitch`` column.
        feature_name: name of the column holding the actual feature values.
        predicted_feature: sequence of predictions.
        prediction_indices: a sequence of index labels mapping predictions to
            notes in the music_df; paired element-wise with
            ``predicted_feature``.
        plot_predicted_notes_only: if true, keep only the rows whose index
            label appears in ``prediction_indices``.
        transpose: if not None, amount added to the ``pitch`` column.
        colormapping: mapping from feature value to color. If None, one is
            built with ``get_colormapping`` from the actual feature values.
        label_notes: if true, pass the predictions and their black/red colors
            to ``plot_piano_roll`` as labels; if false, pass ``None`` for both.
        ax: forwarded to ``plot_piano_roll``.
        title: forwarded to ``plot_piano_roll``.
    """
    pred_col = f"pred_{feature_name}"

    if plot_predicted_notes_only:
        # TODO: (Malcolm 2023-09-25) maybe we want to keep non-note events
        music_df = music_df.filter(items=prediction_indices, axis=0)
    else:
        # Work on a copy so the transposition and helper columns added below
        # do not leak into the caller's dataframe.
        music_df = music_df.copy()

    if transpose is not None:
        music_df["pitch"] = music_df["pitch"] + transpose

    # Rows without a prediction keep None in the prediction column.
    music_df[pred_col] = None
    for pred, i in zip(predicted_feature, prediction_indices):
        music_df.loc[i, pred_col] = pred

    if colormapping is None:
        colormapping = get_colormapping(music_df[feature_name])

    music_df["correct"] = music_df[pred_col] == music_df[feature_name]
    music_df["colors"] = [colormapping[x] for x in music_df[feature_name]]
    if label_notes:
        music_df["label_colors"] = music_df["correct"].map(
            {True: "black", False: "red"}
        )
    music_df = music_df[music_df["type"] == "note"]

    if label_notes:
        labels = music_df[pred_col].to_list()
        label_colors = music_df["label_colors"].to_list()
    else:
        # The "label_colors" column only exists when label_notes is true, so
        # reading it unconditionally raised KeyError.
        # NOTE(review): assumes plot_piano_roll treats labels=None /
        # label_colors=None as "draw no labels" -- confirm against its
        # signature.
        labels = None
        label_colors = None

    plot_piano_roll(
        music_df,
        colors=music_df["colors"].to_list(),
        labels=labels,
        label_colors=label_colors,
        ax=ax,
        title=title,
    )

Args

music_df
dataframe
predicted_feature
sequence of predictions
prediction_indices
a sequence of indices mapping predictions to notes in the music_df