Source code for tab_right.task_detection
"""Task detection utilities for tab-right package."""
from enum import Enum
import pandas as pd
[docs]
class TaskType(Enum):
    """Kinds of prediction task that a label column can represent.

    The member values are the short string identifiers of each task kind.
    """

    # Two-class classification target.
    BINARY = "binary"
    # Multi-class classification target.
    CLASS = "class"
    # Regression target.
    REG = "reg"
[docs]
def detect_task(y: pd.Series) -> "TaskType":
    """Detect the type of task (binary, class, regression) based on the label series y.

    Missing values are dropped before counting distinct labels. The rules are
    applied in this order:

    1. Exactly one distinct label: the task cannot be inferred, so raise.
    2. Float dtype: regression, regardless of the number of distinct values.
    3. Categorical, object or string dtype: binary when there are exactly two
       distinct labels, multi-class otherwise.
    4. Any other dtype (e.g. integer or boolean): binary for two distinct
       values, multi-class for up to ten, regression beyond that.

    A series without any non-null value is not rejected; it simply falls
    through the dtype rules above.

    Args:
        y (pd.Series): The label series to analyze.

    Returns:
        TaskType: The detected task type.

    Raises:
        ValueError: If the label column has only one unique value and the task cannot be inferred.

    """
    n_classes = len(set(y.dropna().unique()))
    if n_classes == 1:
        raise ValueError("Label column has only one unique value; cannot infer task.")

    # Float labels are always treated as a regression target.
    if pd.api.types.is_float_dtype(y):
        return TaskType.REG

    # Label-like dtypes can never be a regression target. Besides categorical
    # and object columns this must include pandas' dedicated string dtypes,
    # which do not compare equal to ``object``; otherwise a string column with
    # more than ten distinct labels would fall through to the numeric rule
    # below and be reported as regression.
    is_label_dtype = (
        isinstance(y.dtype, pd.CategoricalDtype)
        or y.dtype == object
        or pd.api.types.is_string_dtype(y.dtype)
    )
    if is_label_dtype:
        return TaskType.BINARY if n_classes == 2 else TaskType.CLASS

    # Remaining dtypes: decide by cardinality. More than ten distinct values
    # is treated as a continuous target.
    if n_classes == 2:
        return TaskType.BINARY
    if n_classes <= 10:
        return TaskType.CLASS
    return TaskType.REG