"""auto_prep.visualization.numerical: EDA chart generation for numerical features."""

from typing import List, Tuple

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from ..utils.config import config
from ..utils.logging_config import setup_logger
from ..utils.other import save_chart

# Module-level logger shared by the visualizer methods in this module; they use it
# for debug/error messages and for start_operation()/end_operation() bracketing.
logger = setup_logger(__name__)


class NumericalVisualizer:
    """
    Contains methods that generate eda charts for numerical data.

    Will be fed with just numerical columns from original dataset. All
    methods will be called in order defined in :obj:`order`.

    Each method listed in :obj:`order` takes ``(X, y)`` and returns either

    * a single ``(path_to_chart, chart title for latex)`` tuple, where
      ``("", "")`` means that no chart was generated, or
    * a list of such tuples, one per generated page (an empty list when there
      is nothing to plot).

    Charts should be saved via :obj:`save_chart`.
    """

    order = [
        "numerical_distribution_chart",
        "correlation_heatmap_chart",
        "numerical_features_boxplot_chart",
    ]

    # Maximum number of subplots placed on one saved page (2 columns per row).
    _PLOTS_PER_PAGE = 6
    # The heatmap is skipped when there are more numerical columns than this.
    _MAX_HEATMAP_COLUMNS = 25

    @staticmethod
    def _paginate(items: list, page_size: int) -> List[list]:
        """Split ``items`` into consecutive chunks of at most ``page_size`` elements."""
        return [
            items[start : start + page_size]
            for start in range(0, len(items), page_size)
        ]

    @staticmethod
    def _style_axis(ax, title: str, xlabel: str, ylabel: str) -> None:
        """Apply the title, axis labels and tick styling from ``config.chart_settings``."""
        settings = config.chart_settings
        ax.set_title(
            title,
            fontsize=settings["title_fontsize"],
            fontweight=settings["title_fontweight"],
        )
        ax.set_xlabel(xlabel, fontsize=settings["xlabel_fontsize"])
        ax.set_ylabel(ylabel, fontsize=settings["ylabel_fontsize"])
        ax.tick_params(axis="x", rotation=settings["tick_label_rotation"])
        ax.tick_params(
            axis="both",
            which="major",
            labelsize=settings["xlabel_fontsize"],
        )

    @staticmethod
    def numerical_distribution_chart(
        X: pd.DataFrame,
        y: pd.Series,
    ) -> List[Tuple[str, str]]:
        """
        Generates a plot to visualize the distribution of numerical features.

        One histogram is drawn per numerical column of ``X``; the histograms
        are laid out two per row and split into pages of at most six plots.

        Args:
            X: Features to visualize; only ``number`` dtype columns are used.
            y: Target values. Not used by this chart.

        Returns:
            A list of ``(path, title)`` tuples, one per saved page, or an
            empty list when ``X`` has no numerical columns.

        Raises:
            Exception: Any error raised while plotting or saving is logged
                and re-raised unchanged.
        """
        logger.start_operation("Numerical distribution visualizations.")

        try:
            numerical_columns = X.select_dtypes(include="number").columns.tolist()
            if not numerical_columns:
                logger.debug("No numerical features found in the dataset.")
                return []

            logger.debug(
                "Will create numerical distribution visualisations "
                f"for {numerical_columns} columns."
            )

            sns.set_theme(
                style=config.chart_settings["theme"],
                palette=config.chart_settings["palette"],
            )

            pages = NumericalVisualizer._paginate(
                numerical_columns, NumericalVisualizer._PLOTS_PER_PAGE
            )
            chart_list = []
            for page_number, group in enumerate(pages, start=1):
                num_rows = (len(group) + 1) // 2
                fig, axes = plt.subplots(
                    num_rows,
                    2,
                    figsize=(
                        config.chart_settings["plot_width"],
                        config.chart_settings["plot_height_per_row"] * num_rows,
                    ),
                )
                try:
                    axes = axes.flatten()
                    for ax, column in zip(axes, group):
                        sns.histplot(data=X, x=column, ax=ax)
                        NumericalVisualizer._style_axis(
                            ax, f"Distribution of {column}", column, "Count"
                        )

                    # An odd number of plots leaves the last grid cell empty.
                    for unused_ax in axes[len(group):]:
                        unused_ax.axis("off")

                    plt.tight_layout()
                    path = save_chart(
                        name=f"numerical_distribution_page_{page_number}.png"
                    )
                finally:
                    # pyplot keeps figures alive until they are closed; release
                    # this page's figure even if plotting or saving failed.
                    plt.close(fig)

                chart_list.append(
                    (
                        path,
                        f"Numerical Features Distribution - Page {page_number}",
                    )
                )

            return chart_list
        except Exception as e:
            logger.error(f"Failed to generate numerical distribution plot: {str(e)}")
            raise
        finally:
            logger.end_operation()

    @staticmethod
    def correlation_heatmap_chart(
        X: pd.DataFrame,
        y: pd.Series,
    ) -> Tuple[str, str]:
        """
        Generates a plot to visualize the correlation between features.

        ``X`` and ``y`` are concatenated column-wise and the pairwise
        correlation of all numerical columns is drawn as an annotated heatmap.

        Args:
            X: Feature columns.
            y: Target column; included in the heatmap when it is numerical.

        Returns:
            ``(path, title)`` of the saved chart. ``("", "")`` when there are
            no numerical columns, and ``("", <explanation>)`` when there are
            more than 25 numerical columns and the heatmap is skipped.

        Raises:
            Exception: Any error raised while plotting or saving is logged
                and re-raised unchanged.
        """
        logger.start_operation("Correlation heatmap visualization.")

        try:
            df = pd.concat([X, y], axis=1)
            numerical_columns = df.select_dtypes(include="number").columns.tolist()
            if not numerical_columns:
                logger.debug("No numerical features found in the dataset.")
                return "", ""

            if len(numerical_columns) > NumericalVisualizer._MAX_HEATMAP_COLUMNS:
                logger.debug(
                    f"Too many numerical columns ({len(numerical_columns)}) for heatmap visualization."
                )
                return "", "Too many numerical columns for heatmap visualization."

            logger.debug(f"Will create heatmap for {numerical_columns} columns.")

            fig = plt.figure(
                figsize=(
                    config.chart_settings["plot_width"],
                    config.chart_settings["plot_height_per_row"] * 2.5,
                )
            )
            try:
                sns.heatmap(
                    df[numerical_columns].corr(numeric_only=False),
                    annot=True,
                    cmap=config.chart_settings.get("heatmap_cmap", "coolwarm"),
                    fmt=config.chart_settings.get("heatmap_fmt", ".2f"),
                )
                plt.title(
                    "Correlation Heatmap",
                    fontsize=config.chart_settings["title_fontsize"],
                    fontweight=config.chart_settings["title_fontweight"],
                )
                path = save_chart(name="correlation_heatmap.png")
            finally:
                # Release the figure whether or not plotting/saving succeeded.
                plt.close(fig)

            return path, "Correlation heatmap."
        except Exception as e:
            logger.error(f"Failed to generate correlation heatmap plot: {str(e)}")
            raise
        finally:
            logger.end_operation()

    @staticmethod
    def numerical_features_boxplot_chart(
        X: pd.DataFrame,
        y: pd.Series,
    ) -> List[Tuple[str, str]]:
        """
        Generates boxplots for numerical features, split into multiple pages
        if necessary.

        Args:
            X: Features to visualize; only ``number`` dtype columns are used.
            y: Target values. Not used by this chart.

        Returns:
            A list of ``(path, title)`` tuples, one per saved page, or an
            empty list when ``X`` has no numerical columns.

        Raises:
            Exception: A failure while drawing a column is wrapped in an
                ``Exception`` naming that column; every error is logged and
                re-raised.
        """
        logger.start_operation("Boxplot visualization for numerical features.")

        try:
            numerical_columns = X.select_dtypes(include="number").columns.tolist()
            if not numerical_columns:
                logger.debug("No numerical features found in the dataset.")
                return []

            logger.debug(f"Will create boxplot for {numerical_columns} columns.")

            pages = NumericalVisualizer._paginate(
                numerical_columns, NumericalVisualizer._PLOTS_PER_PAGE
            )
            paths = []
            for page_number, columns_to_plot in enumerate(pages, start=1):
                num_rows = (len(columns_to_plot) + 1) // 2
                fig, axes = plt.subplots(
                    num_rows,
                    2,
                    figsize=(
                        config.chart_settings["plot_width"],
                        config.chart_settings["plot_height_per_row"] * num_rows,
                    ),
                )
                try:
                    axes = axes.flatten()

                    try:
                        for ax, column in zip(axes, columns_to_plot):
                            # NOTE(review): recent seaborn releases warn when
                            # `palette` is passed without `hue`; confirm against
                            # the pinned seaborn version before changing this.
                            sns.boxplot(
                                data=X,
                                x=column,
                                ax=ax,
                                palette=config.chart_settings["palette"],
                            )
                            NumericalVisualizer._style_axis(
                                ax, f"Boxplot of {column}", column, "Value"
                            )
                    except Exception as e:
                        # `column` is the column being drawn when the error occurred.
                        raise Exception(f"Error in column {column}") from e

                    # An odd number of plots leaves the last grid cell empty.
                    for unused_ax in axes[len(columns_to_plot):]:
                        unused_ax.axis("off")

                    plt.tight_layout()
                    path = save_chart(name=f"boxplot_page_{page_number}.png")
                finally:
                    # Release the figure whether or not plotting/saving succeeded.
                    plt.close(fig)

                paths.append((path, f"Boxplot page {page_number}"))

            return paths
        except Exception as e:
            logger.error(f"Failed to generate boxplot visualisations plot: {str(e)}")
            raise
        finally:
            logger.end_operation()