Source code for auto_prep.visualization.categorical

from typing import List, Tuple

import matplotlib.pyplot as plt  # noqa: F401
import pandas as pd
import seaborn as sns  # noqa: F401

from ..utils.config import config
from ..utils.logging_config import setup_logger
from ..utils.other import save_chart  # noqa: F401

logger = setup_logger(__name__)


class CategoricalVisualizer:
    """
    Contains methods that generate eda charts for categorical data.
    Will be fed with just categorical columns from original dataset.

    All methods will be called in order defined in :obj:`order`.
    Each method that would be called should return the chart(s) it produced
    as (path_to_chart, chart title for latex) pairs. The method defined here
    returns a list of such pairs, and an empty list when there is nothing
    to plot.

    Charts should be saved via :obj:`save_chart`.
    """

    # Names of the chart-producing methods, in the order they should run.
    order = ["categorical_distribution_chart"]

    @staticmethod
    def categorical_distribution_chart(
        X: pd.DataFrame, y: pd.Series
    ) -> List[Tuple[str, str]]:
        """
        Generates plots to visualize the distribution of categorical features.

        Every non-numeric column of ``X`` with at most 15 unique values gets a
        horizontal count plot, with categories ordered by frequency. Plots are
        laid out two per row, at most six per figure ("page"); each page is
        saved via :obj:`save_chart`.

        Args:
            X: Input features. Only non-numeric columns are considered.
            y: Target variable. Not used by this chart; accepted so that the
                method shares the ``(X, y)`` call signature.

        Returns:
            A list with one ``(path_to_chart, chart title)`` tuple per saved
            page, or an empty list when no column qualifies.

        Raises:
            Exception: Any error raised while building or saving the charts is
                logged and re-raised unchanged.
        """
        settings = config.chart_settings
        sns.set_theme(style=settings["theme"])

        logger.start_operation("Categorical distribution visualization.")
        try:
            categorical_columns = X.select_dtypes(exclude=["number"]).columns.tolist()
            if not categorical_columns:
                logger.debug("No categorical features found in the dataset.")
                return []

            # High-cardinality columns would yield unreadable bar charts.
            categorical_columns = [
                col for col in categorical_columns if X[col].nunique() <= 15
            ]
            if not categorical_columns:
                logger.debug(
                    "No categorical features with less than or equal to 15 unique values found in the dataset."
                )
                return []

            logger.debug(
                "Will create categorical distribution visualization chart "
                f"for {categorical_columns} columns."
            )

            # Split the columns into pages so each figure stays legible.
            max_plots_per_page = 6
            categorical_groups = [
                categorical_columns[i : i + max_plots_per_page]
                for i in range(0, len(categorical_columns), max_plots_per_page)
            ]

            chart_list = []
            for page_number, group in enumerate(categorical_groups, start=1):
                page_title = f"Categorical Features Distribution - Page {page_number}"

                # Two plots per row; round up so an odd count gets a last row.
                num_rows = (len(group) + 1) // 2
                fig, axes = plt.subplots(
                    num_rows,
                    2,
                    figsize=(
                        settings["plot_width"],
                        settings["plot_height_per_row"] * num_rows,
                    ),
                )
                try:
                    axes = axes.flatten()

                    for ax, column in zip(axes, group):
                        # Most frequent category first; the index length equals
                        # the number of distinct (non-null) categories.
                        category_order = X[column].value_counts().index
                        sns.countplot(
                            data=X,
                            y=column,
                            order=category_order,
                            ax=ax,
                            palette=sns.color_palette(
                                settings["palette"], len(category_order)
                            ),
                        )
                        ax.set_title(
                            f"Distribution of {column}",
                            fontsize=settings["title_fontsize"],
                            fontweight=settings["title_fontweight"],
                        )
                        ax.set_xlabel("Count", fontsize=settings["xlabel_fontsize"])
                        ax.set_ylabel(column, fontsize=settings["ylabel_fontsize"])

                    # Hide the leftover axes of an incomplete last row.
                    for unused_ax in axes[len(group) :]:
                        unused_ax.axis("off")

                    fig.suptitle(
                        page_title,
                        fontsize=settings["title_fontsize"],
                        fontweight=settings["title_fontweight"],
                        y=1.0,
                    )
                    fig.tight_layout(pad=2.0)

                    # ``fig`` is still the current pyplot figure at this point.
                    path = save_chart(
                        name=f"categorical_distribution_page_{page_number}.png"
                    )
                finally:
                    # Release the figure even if plotting/saving failed; this
                    # is a no-op when the figure has already been closed.
                    plt.close(fig)

                chart_list.append((path, page_title))
        except Exception as e:
            logger.error(f"Error in categorical_distribution_chart: {e}")
            raise
        finally:
            logger.end_operation()

        return chart_list