import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tab_right.drift.drift_calculator import DriftCalculator
from tab_right.plotting.drift_plotter import DriftPlotter

# Fixed seed so the reference and all comparison samples are reproducible.
np.random.seed(42)
n_samples = 500
ref_data = np.random.normal(0, 1, n_samples)

# Current-period samples drawn in order of increasing shift from the
# reference distribution: each (mean, std) pair moves further from (0, 1).
slight_drift, moderate_drift, severe_drift = [
    np.random.normal(loc, scale, n_samples)
    for loc, scale in ((0.2, 1.1), (0.5, 1.3), (2.0, 1.8))
]

# Panel captions, paired positionally with the samples they describe.
titles = ['Slight Drift', 'Moderate Drift', 'Severe Drift']
drift_data = [slight_drift, moderate_drift, severe_drift]

# A single row of three side-by-side panels, one per drift level.
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 5))

# The reference frame is the same for every comparison, so build it once.
df_ref = pd.DataFrame({'value': ref_data})

# For each drift level: score it with tab_right, render tab_right's
# distribution plot, and copy that plot's lines into the matching panel.
for ax, title, current_data in zip(axes, titles, drift_data):
    df_cur = pd.DataFrame({'value': current_data})

    # Drift score for the 'value' column. Assumes the calculator's result has
    # a 'score' column with the feature in its first row — TODO confirm
    # against tab_right's DriftCalculator output.
    drift_calc = DriftCalculator(df_ref, df_cur)
    drift_result = drift_calc()
    drift_score = round(drift_result.iloc[0]['score'], 3)

    # plot_single returns its own figure. Copy its line artists onto our
    # subplot, then close it. Only Line2D objects are copied; any bars or
    # patches it may draw are not transferred.
    dist_fig = DriftPlotter(drift_calc).plot_single('value')
    try:
        for line in dist_fig.axes[0].lines:
            ax.plot(line.get_xdata(), line.get_ydata(),
                    color=line.get_color(),
                    linestyle=line.get_linestyle(),
                    marker=line.get_marker(),
                    label=line.get_label())
    finally:
        # Always close the helper figure so plt.show() below does not
        # display it, even if copying its contents fails.
        plt.close(dist_fig)

    ax.set_title(f"{title}\nDrift Score: {drift_score}")
    # legend() with no labelled artists only emits a warning, so skip it.
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend()

# Lay out the panel figure explicitly rather than relying on pyplot's
# "current figure", which plot_single may have changed.
fig.tight_layout()
plt.show()