Source code for tab_right.plotting.plot_segmentations

"""Module for plotting segmentation results from decision tree models.

This module provides functions for visualizing the segmentation results
generated by decision tree models, helping to interpret the model's decisions.

Common parameters across functions and classes:
-----------------------------------------------
df : pd.DataFrame
    A DataFrame containing the groups defined by the segmentation.
    For single segmentation:
        - `segment_id`: The ID of the segment, for grouping.
        - `segment_name`: (str) the range or category of the feature.
        - `score`: (float) The calculated error metric for the segment.
    For double segmentation:
        - `segment_id`: The ID of the segment, for grouping.
        - `feature_1`: (str) the range or category of the first feature.
        - `feature_2`: (str) the range or category of the second feature.
        - `score`: (float) The calculated error metric for the segment.
metric_name : str, default="score"
    The name of the metric column in the DataFrame.
lower_is_better : bool, default=True
    Whether lower values of the metric indicate better performance.
    Affects the color scale in visualizations (green for better, red for worse).
backend : str, default="plotly"
    The plotting backend to use. Either "plotly" or "matplotlib".
"""

from dataclasses import dataclass
from typing import Dict, Literal, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from matplotlib.figure import Figure as MatplotlibFigure
from plotly.graph_objects import Figure as PlotlyFigure

from tab_right.base_architecture.seg_plotting_protocols import Figure

# Type definitions
# Backend: the two backend names used throughout this module to choose
# between the Plotly and the Matplotlib implementations.
Backend = Literal["plotly", "matplotlib"]
# ColorMap: either a string or a list -- presumably a named colormap or an
# explicit colorscale; TODO confirm against usages outside this view.
ColorMap = Union[str, list]


def _prepare_data(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of the segmentation frame ordered by ``segment_id``.

    Ordering the rows first keeps the segments in the same position every
    time a plot is produced from the same data.

    Parameters
    ----------
    df : pd.DataFrame
        Segmentation results; must contain a ``segment_id`` column.

    Returns
    -------
    pd.DataFrame
        A new DataFrame with rows in ascending ``segment_id`` order. The
        input frame is left untouched and the original index labels are kept.

    """
    ordered = df.sort_values(by="segment_id")
    return ordered


def _get_color_scheme(lower_is_better: bool = True, backend: Backend = "plotly") -> Dict[str, Union[str, list]]:
    """Build the colour configuration for a backend and metric direction.

    The "better" end of the metric is always rendered green and the "worse"
    end red, with yellow in between.

    Parameters
    ----------
    lower_is_better : bool, default=True
        Whether lower values of the metric indicate better performance.
    backend : str, default="plotly"
        The plotting backend to use. Any value other than "plotly" is
        handled as matplotlib.

    Returns
    -------
    dict
        ``{"colorscale": [[stop, colour], ...]}`` for plotly, or
        ``{"cmap": name}`` holding a matplotlib colormap name otherwise.

    """
    if backend != "plotly":
        # matplotlib: the "_r" variant runs green (low) -> red (high).
        return {"cmap": "RdYlGn_r" if lower_is_better else "RdYlGn"}

    # plotly: list the colours from the low end of the scale to the high end.
    ramp = ["green", "yellow", "red"]
    if not lower_is_better:
        ramp.reverse()
    return {"colorscale": [[stop, colour] for stop, colour in zip((0, 0.5, 1), ramp)]}


def plot_single_segmentation(df: pd.DataFrame, lower_is_better: bool = True, backend: Backend = "plotly") -> Figure:
    """Draw a single segmentation as a bar chart with the chosen backend.

    Parameters
    ----------
    df : pd.DataFrame
        See module docstring for format details.
    lower_is_better : bool, default=True
        Whether lower values of the metric indicate better performance.
    backend : str, default="plotly"
        The plotting backend to use. "plotly" selects the Plotly
        implementation; any other value selects the Matplotlib one.

    Returns
    -------
    Figure
        A bar chart showing each segment with its corresponding score.

    """
    # Pick the backend-specific renderer, then delegate to it.
    renderer = _plot_single_segmentation_plotly if backend == "plotly" else _plot_single_segmentation_matplotlib
    return renderer(df, lower_is_better)
def _plot_single_segmentation_plotly(df: pd.DataFrame, lower_is_better: bool = True) -> PlotlyFigure:
    """Render the single segmentation as a Plotly bar chart.

    Parameters
    ----------
    df : pd.DataFrame
        See module docstring for format details.
    lower_is_better : bool, default=True
        Whether lower values of the metric indicate better performance.

    Returns
    -------
    PlotlyFigure
        A Plotly figure with one bar per segment, coloured by score.

    """
    ordered = _prepare_data(df)
    colorscale = _get_color_scheme(lower_is_better, "plotly")["colorscale"]
    scores = ordered["score"]

    # One bar per segment; the score drives both bar height and bar colour.
    bar_trace = go.Bar(
        x=ordered["segment_name"].astype(str),
        y=scores,
        marker={
            "color": scores,
            "colorscale": colorscale,
            "colorbar": {"title": "Score"},
        },
        text=scores.round(3),
        textposition="auto",
    )
    fig = go.Figure(data=[bar_trace])

    layout_options = {
        "title": "Segmentation Analysis by Feature",
        "xaxis_title": "Feature Segments",
        "yaxis_title": "Error Score",
        "template": "plotly_white",
        "coloraxis_showscale": True,
    }
    fig.update_layout(**layout_options)
    return fig


def _plot_single_segmentation_matplotlib(df: pd.DataFrame, lower_is_better: bool = True) -> MatplotlibFigure:
    """Render the single segmentation as a Matplotlib bar chart.

    Parameters
    ----------
    df : pd.DataFrame
        See module docstring for format details.
    lower_is_better : bool, default=True
        Whether lower values of the metric indicate better performance.

    Returns
    -------
    MatplotlibFigure
        A Matplotlib figure with one bar per segment, coloured by score,
        value labels above the bars and a colorbar.

    """
    ordered = _prepare_data(df)
    fig, ax = plt.subplots(figsize=(10, 6))

    cmap_name = _get_color_scheme(lower_is_better, "matplotlib")["cmap"]
    assert isinstance(cmap_name, str), "matplotlib cmap should be a string"

    scores = ordered["score"]
    # With a single segment there is no score range to span, so fall back
    # to a fixed 0..1 normalisation.
    if len(ordered) > 1:
        norm = plt.Normalize(float(scores.min()), float(scores.max()))
    else:
        norm = plt.Normalize(0, 1)
    cmap = plt.get_cmap(cmap_name)
    bar_colors = cmap(norm(scores.values.astype(np.float64)))

    rects = ax.bar(ordered["segment_name"].astype(str), scores, color=bar_colors)

    # Print each bar's value just above its top edge.
    for rect in rects:
        bar_height = rect.get_height()
        label_x = rect.get_x() + rect.get_width() / 2.0
        ax.text(label_x, bar_height + 0.01, f"{bar_height:.3f}", ha="center", va="bottom", fontsize=9)

    # The bars carry explicit colours, so the colorbar is built from a
    # standalone mappable sharing the same colormap and normalisation.
    mappable = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    mappable.set_array([])
    colorbar = plt.colorbar(mappable, ax=ax)
    colorbar.set_label("Score")

    ax.set_title("Segmentation Analysis by Feature")
    ax.set_xlabel("Feature Segments")
    ax.set_ylabel("Error Score")
    plt.xticks(rotation=45, ha="right")
    plt.tight_layout()
    return fig


# For backward compatibility
def plot_single_segmentation_mp(df: pd.DataFrame, lower_is_better: bool = True) -> MatplotlibFigure:
    """Plot the single segmentation with matplotlib (compatibility wrapper).

    Equivalent to calling ``plot_single_segmentation`` with
    ``backend="matplotlib"``; retained so existing callers keep working.

    Parameters
    ----------
    df : pd.DataFrame
        See module docstring for format details.
    lower_is_better : bool, default=True
        Whether lower values indicate better performance.

    Returns
    -------
    MatplotlibFigure
        A matplotlib bar chart showing each segment with its corresponding score.

    """
    figure = plot_single_segmentation(df, lower_is_better, backend="matplotlib")
    return figure
# For backward compatibility
def plot_single_segmentation_impl(df: pd.DataFrame, lower_is_better: bool = True) -> PlotlyFigure:
    """Plot the single segmentation with Plotly (compatibility wrapper).

    Delegates directly to ``_plot_single_segmentation_plotly``; retained so
    existing callers keep working.

    Parameters
    ----------
    df : pd.DataFrame
        See module docstring for format details.
    lower_is_better : bool, default=True
        Whether lower values indicate better performance.

    Returns
    -------
    PlotlyFigure
        A Plotly bar chart.

    """
    figure = _plot_single_segmentation_plotly(df, lower_is_better)
    return figure
@dataclass
class DoubleSegmPlotting:
    """Plot a double (two-feature) segmentation as a heatmap.

    The instance holds the segmentation DataFrame and the plotting options;
    ``plot_heatmap`` renders it with either Plotly or Matplotlib.
    See the module docstring for parameter details.

    Parameters
    ----------
    df : pd.DataFrame
        Double-segmentation results with ``feature_1``, ``feature_2`` and the
        metric column.
    metric_name : str, default="score"
        The name of the metric column in ``df``.
    lower_is_better : bool, default=True
        Whether lower values of the metric indicate better performance.
    backend : str, default="plotly"
        "matplotlib" selects the Matplotlib renderer; any other value
        selects Plotly.

    """

    df: pd.DataFrame
    metric_name: str = "score"
    lower_is_better: bool = True
    backend: Backend = "plotly"

    def get_heatmap_df(self) -> pd.DataFrame:
        """Pivot the double segmentation into a heatmap-shaped DataFrame.

        Returns
        -------
        pd.DataFrame
            columns: feature_1 ranges or categories
            index: feature_2 ranges or categories
            content: the metric value of the segment at that combination
            (NaN where a combination is absent from ``df``).

        Raises
        ------
        ValueError
            Raised by ``DataFrame.pivot`` when the same
            (feature_2, feature_1) pair occurs more than once.

        """
        return self.df.pivot(index="feature_2", columns="feature_1", values=self.metric_name)

    @staticmethod
    def _annotation_color(value: float, vmin: float, vmax: float) -> str:
        """Choose a readable label colour for a heatmap cell.

        The decision is made on the cell's position within the colour range
        ``[vmin, vmax]`` rather than on the raw value, so it stays correct
        for metrics that are not confined to 0..1. The middle of the
        red-yellow-green maps is pale, so it gets black text; both ends are
        saturated and get white text.

        Parameters
        ----------
        value : float
            The metric value shown in the cell.
        vmin, vmax : float
            The lower and upper limits of the colour scale.

        Returns
        -------
        str
            "black" for cells strictly inside the 0.3..0.7 band of the colour
            range, otherwise "white". A degenerate range (``vmax <= vmin``)
            is treated as the low end of the scale.

        """
        span = vmax - vmin
        position = (value - vmin) / span if span > 0 else 0.0
        return "black" if 0.3 < position < 0.7 else "white"

    def _plot_heatmap_plotly(self) -> PlotlyFigure:
        """Plot the double segmentation as a heatmap using Plotly.

        Returns
        -------
        PlotlyFigure
            A Plotly heatmap with one cell per (feature_1, feature_2)
            combination, annotated with the metric rounded to 3 decimals.

        """
        heatmap_df = self.get_heatmap_df()
        color_scheme = _get_color_scheme(self.lower_is_better, "plotly")

        fig = go.Figure(
            data=go.Heatmap(
                z=heatmap_df.values,
                x=heatmap_df.columns,
                y=heatmap_df.index,
                colorscale=color_scheme["colorscale"],
                text=heatmap_df.round(3).values,
                texttemplate="%{text}",
                colorbar=dict(title=self.metric_name),
            )
        )
        fig.update_layout(
            title="Double Segmentation Heatmap",
            xaxis_title="Feature 1",
            yaxis_title="Feature 2",
            template="plotly_white",
        )
        return fig

    def _plot_heatmap_matplotlib(self) -> MatplotlibFigure:
        """Plot the double segmentation as a heatmap using Matplotlib.

        Returns
        -------
        MatplotlibFigure
            A Matplotlib heatmap with one cell per (feature_1, feature_2)
            combination, annotated with the metric to 3 decimals.

        """
        # Set non-interactive backend to avoid Tkinter issues.
        # NOTE: this selects the backend for the whole process, not only
        # for the figure created here.
        import matplotlib

        matplotlib.use("Agg")

        heatmap_df = self.get_heatmap_df()
        n_rows, n_cols = len(heatmap_df.index), len(heatmap_df.columns)

        fig, ax = plt.subplots(figsize=(10, 8))

        color_scheme = _get_color_scheme(self.lower_is_better, "matplotlib")
        cmap = color_scheme["cmap"]
        assert isinstance(cmap, str), "matplotlib cmap should be a string"

        # pcolormesh takes the cell *edges*, hence one more coordinate than
        # there are cells along each axis.
        x_edges = np.arange(n_cols + 1)
        y_edges = np.arange(n_rows + 1)
        mesh = ax.pcolormesh(x_edges, y_edges, heatmap_df.values, cmap=cmap)

        # Tick labels sit at the cell centres (edge + 0.5).
        ax.set_xticks(np.arange(n_cols) + 0.5)
        ax.set_yticks(np.arange(n_rows) + 0.5)
        ax.set_xticklabels(heatmap_df.columns)
        ax.set_yticklabels(heatmap_df.index)
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right")

        cbar = fig.colorbar(mesh, ax=ax)
        cbar.set_label(self.metric_name)

        # Annotate every populated cell. The label colour follows the cell's
        # position on the colour scale actually used by the mesh.
        vmin, vmax = mesh.get_clim()
        for (row, col), value in np.ndenumerate(heatmap_df.values):
            if pd.isna(value):
                continue
            ax.text(
                col + 0.5,
                row + 0.5,
                f"{value:.3f}",
                ha="center",
                va="center",
                color=self._annotation_color(float(value), float(vmin), float(vmax)),
            )

        ax.set_title("Double Segmentation Heatmap")
        ax.set_xlabel("Feature 1")
        ax.set_ylabel("Feature 2")

        plt.tight_layout()
        return fig

    def plot_heatmap(self) -> Figure:
        """Plot the double segmentation as a heatmap with the configured backend.

        Returns
        -------
        Figure
            A heatmap showing each segment with its corresponding score.
            ``backend == "matplotlib"`` yields a Matplotlib figure; every
            other value falls back to Plotly.

        """
        if self.backend == "matplotlib":
            return self._plot_heatmap_matplotlib()
        # Default to plotly
        return self._plot_heatmap_plotly()