Source code for tab_right.segmentations.calc_seg
"""Module for calculating segmentation metrics."""
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Union # Standard library
import pandas as pd # Third-party
from pandas.api.typing import DataFrameGroupBy # Third-party
[docs]
@dataclass
class SegmentationCalc:
"""Implementation of BaseSegmentationCalc protocol.
Calculates scores for pre-defined segments.
Attributes
----------
gdf : DataFrameGroupBy
Grouped DataFrame, each group represents a segment (grouped by segment_id).
label_col : str
Column name for the true target values.
prediction_col : Union[str, List[str]]
Column names for the predicted values.
segment_names : Dict[int, Any]
Mapping from segment_id to the original group name (category or interval).
"""
gdf: DataFrameGroupBy
label_col: str
prediction_col: Union[str, List[str]]
segment_names: Dict[int, Any]
def _reduce_metric_results(
self,
results: Union[float, pd.Series],
) -> float:
"""Reduce the metric results to a single value if the metric produces a series.
If it produces a single value, return it. Used for getting a single value for each segment.
Parameters
----------
results : Union[float, pd.Series]
The metric results to reduce.
Returns
-------
float
The reduced metric result.
"""
if isinstance(results, pd.Series):
return float(results.mean())
return float(results)
def __call__(self, metric: Callable) -> pd.DataFrame:
"""Apply the metric to each group and return scores with segment names.
Ensures all segments defined in `segment_names` are included in the output,
assigning NaN to segments with no data.
Parameters
----------
metric : Callable
Metric function to apply.
Returns
-------
pd.DataFrame
DataFrame with segment_id, name, and score.
"""
# Initialize results with all segment IDs from segment_names, default score NaN
results = {segment_id: float("nan") for segment_id in self.segment_names}
# Calculate scores for segments present in the grouped data
for name, group in self.gdf:
segment_id = int(name) # Ensure name is treated as the segment_id (integer)
if segment_id in results: # Process only if the segment is expected
if not group.empty:
y_true = group[self.label_col]
y_pred = group[self.prediction_col]
score = metric(y_true, y_pred)
results[segment_id] = self._reduce_metric_results(score)
# If group is empty but segment_id is in results, it keeps the NaN score
# Convert results dictionary to DataFrame
scores_df = pd.DataFrame(list(results.items()), columns=["segment_id", "score"])
# Add the segment names using the stored mapping
scores_df["name"] = scores_df["segment_id"].map(self.segment_names)
# Convert interval names to strings for consistent output
scores_df["name"] = scores_df["name"].apply(lambda x: str(x) if isinstance(x, pd.Interval) else x)
# Reorder columns and ensure correct order even if some segments were empty
scores_df = scores_df.sort_values("segment_id").reset_index(drop=True)
return scores_df[["segment_id", "name", "score"]]