Source code for auto_prep.visualization.eda

from typing import Tuple

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from ..utils.config import config
from ..utils.logging_config import setup_logger
from ..utils.other import save_chart

logger = setup_logger(__name__)


class EdaVisualizer:
    """
    Contains methods that generate basic EDA charts.

    Will be fed with the entire original dataset. All methods will be called
    in the order defined in :obj:`order`. Each method that is called should
    return a tuple of (path_to_chart, chart title for latex) - if there is no
    need for chart generation it should return ("", ""). Charts should be
    saved via :obj:`save_chart`.
    """

    # Names of the chart-generating static methods, in the order in which
    # they are meant to be called.
    order = [
        "target_distribution_chart",
        "missing_values_chart",
    ]

    @staticmethod
    def target_distribution_chart(
        X: pd.DataFrame,  # noqa: F841
        y: pd.Series,
        task: str = "classification",
    ) -> Tuple[str, str]:
        """
        Generates a plot to visualize the distribution of the target variable.

        For classification a count plot is drawn and every bar is annotated
        with its count as a percentage of ``len(y)``. For regression a density
        histogram (30 bins) is drawn with dashed vertical lines marking the
        mean and the median.

        Args:
            X (pd.DataFrame): Input features (not used directly, included for
                API consistency).
            y (pd.Series): Target variable to visualize.
            task (str): Type of task, either "classification" or "regression".

        Returns:
            Tuple[str, str]: Path to the saved chart and a description of the
                chart.

        Raises:
            ValueError: If ``task`` is neither "classification" nor
                "regression".
        """
        logger.start_operation("Target distribution visualization.")
        fig = None
        try:
            # Validate before touching matplotlib so an invalid task never
            # leaves an orphaned, half-configured figure behind.
            if task not in ("classification", "regression"):
                raise ValueError(
                    f"Unsupported task type: {task}. Use 'classification' or 'regression'."
                )

            sns.set_theme(
                style=config.chart_settings["theme"],
                palette=config.chart_settings["palette"],
            )
            # An unnamed Series would otherwise render as "Distribution of None".
            target_name = y.name if y.name is not None else "target"
            y_df = y.to_frame(name="target")
            fig = plt.figure(figsize=(10, 6))

            if task == "classification":
                sns.countplot(
                    data=y_df,
                    x="target",
                    palette=config.chart_settings["palette"],
                    hue="target",
                )
                ax = plt.gca()
                total = len(y)
                for p in ax.patches:
                    height = p.get_height()
                    # Skip bars without a defined height so no "nan%" label
                    # is drawn.
                    if pd.isna(height):
                        continue
                    # Offset in points (not data units) so the label gap is
                    # the same regardless of the count scale.
                    ax.annotate(
                        f"{height / total:.2%}",
                        xy=(p.get_x() + p.get_width() / 2, height),
                        xytext=(0, 3),
                        textcoords="offset points",
                        ha="center",
                        va="bottom",
                    )
                plt.title(f"Distribution of {target_name}")
                plt.xlabel("Target Classes")
                plt.ylabel("Count")
                path = save_chart(name="target_distribution_classification.png")
            else:  # task == "regression" (validated above)
                sns.histplot(
                    data=y_df,
                    x="target",
                    bins=30,
                    stat="density",
                )
                mean_value = y.mean()
                median_value = y.median()
                plt.axvline(
                    mean_value,
                    color="red",
                    linestyle="--",
                    label=f"Mean: {mean_value:.2f}",
                )
                plt.axvline(
                    median_value,
                    color="green",
                    linestyle="--",
                    label=f"Median: {median_value:.2f}",
                )
                plt.legend()
                plt.title(f"Distribution of {target_name}")
                plt.xlabel("Target Value")
                plt.ylabel("Density")
                path = save_chart(name="target_distribution_regression.png")

            return path, "Target distribution."
        except Exception as e:
            logger.error(
                f"Failed to generate target distribution plot for {task}: {str(e)}"
            )
            raise
        finally:
            logger.end_operation()
            # Release the figure on both success and failure paths;
            # plt.close is a no-op for a figure that is already closed.
            if fig is not None:
                plt.close(fig)

    @staticmethod
    def missing_values_chart(
        X: pd.DataFrame,
        y: pd.Series,  # noqa: F841
    ) -> Tuple[str, str]:
        """
        Generates a plot to visualize the percentage of missing values for
        each feature in the given DataFrame.

        Only columns with at least one missing value are plotted, sorted from
        the highest to the lowest missing percentage.

        Args:
            X (pd.DataFrame): Input features to inspect for missing values.
            y (pd.Series): Target variable (not used, included for API
                consistency).

        Returns:
            Tuple[str, str]: Path to the saved chart and a description of the
                chart, or ``("", "")`` when no column has missing values.
        """
        logger.start_operation("Missing values visualizations.")
        fig = None
        try:
            # The mean of the boolean null-mask is the fraction of missing
            # entries per column.
            missing = X.isnull().mean() * 100
            missing = missing[missing > 0].sort_values(ascending=False)
            if missing.empty:
                # Return before creating a figure so nothing is left open.
                logger.debug("No missing values found in the dataset.")
                return "", ""

            logger.debug(
                f"Will create missing values chart for {list(missing.index)} columns."
            )
            fig = plt.figure(figsize=(10, 6))
            sns.barplot(
                x=missing.index,
                y=missing.values,
                palette=config.chart_settings["palette"],
            )
            # Right-align rotated labels so long column names line up with
            # their bars.
            plt.xticks(rotation=45, ha="right")
            plt.title("Percentage of Missing Values by Feature")
            plt.ylabel("Missing values (%)")
            path = save_chart(name="missing_values.png")
            return path, "Missing values."
        except Exception as e:
            logger.error(f"Failed to generate missing values plot: {str(e)}")
            raise
        finally:
            logger.end_operation()
            # plt.close is a no-op for a figure that is already closed.
            if fig is not None:
                plt.close(fig)