Source code for PTO.database.plot

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from dataclasses import dataclass
import logging
import pandas as pd
from ..utils.utilities import logger_default
from .mappers import LABEL_MAPPER, SCALE_MAPPER
import matplotlib.colors as mcolors

# Module-level logger: the standard named logger for this module, passed through
# the project's logger_default() helper (imported from ..utils.utilities; what it
# configures is not visible here).
logger = logger_default(logging.getLogger(__name__))


@dataclass
class ColorPopulationDiagram:
    """
    A theme for a colour population diagram.

    Attributes
    ----------
    theme : str
        The name of the theme.
    cmap : str or matplotlib.colors.Colormap
        Colormap for the filled KDE background (passed as ``cmap`` to
        ``seaborn.kdeplot``). Either a registered colormap name such as
        ``'Greys'`` or a Colormap instance; the module's coloured themes use
        ``ListedColormap`` instances.
    scatter : dict
        Keyword arguments for the population scatter plot (``Axes.scatter``).
    highlight_scatter : dict
        Keyword arguments for the highlighted-sample scatter plot
        (``Axes.scatter``).
    """

    theme: str
    # String annotation: documents the real accepted types without evaluating
    # ``mcolors`` at class-creation time.
    cmap: "str | mcolors.Colormap"
    scatter: dict
    highlight_scatter: dict
RedPopulationDiagram = ColorPopulationDiagram( theme='red', cmap=mcolors.ListedColormap(sns.color_palette( "Reds", as_cmap=True)(np.linspace(0.2, 1, 256))), scatter={ 'color': sns.color_palette('muted')[3], 'edgecolor': 'black', 'alpha': 0.25, 's': 30 }, highlight_scatter={ 'color': sns.color_palette('bright')[0], 'edgecolor': 'black', 's': 200, 'zorder': 10 } ) BluePopulationDiagram = ColorPopulationDiagram( theme='blue', cmap=mcolors.ListedColormap(sns.color_palette( "Blues", as_cmap=True)(np.linspace(0.2, 1, 256))), scatter={ 'color': sns.color_palette('dark')[0], 'edgecolor': 'black', 'alpha': 0.2, 's': 30, }, highlight_scatter={ 'color': sns.color_palette('bright')[2], 'edgecolor': 'black', 's': 200, 'zorder': 10 } ) GreenPopulationDiagram = ColorPopulationDiagram( theme='green', cmap=mcolors.ListedColormap(sns.color_palette( "Greens", as_cmap=True)(np.linspace(0.2, 1, 256))), scatter={ 'color': sns.color_palette('dark')[2], 'edgecolor': 'black', 'alpha': 0.2, 's': 30, }, highlight_scatter={ 'color': sns.color_palette('bright')[0], 'edgecolor': 'black', 's': 200, 'zorder': 10 } ) GreyScalePopulationDiagram = ColorPopulationDiagram( theme='grayscale', cmap='Greys', scatter={ 'color': 'black', 'edgecolor': 'black', 'alpha': 0.1, 's': 30, }, highlight_scatter={ 'color': sns.color_palette('bright')[0], 'edgecolor': 'black', 's': 200, 'zorder': 10 } ) PurplePopulationDiagram = ColorPopulationDiagram( theme='purple', cmap=mcolors.ListedColormap(sns.color_palette( "Purples", as_cmap=True)(np.linspace(0.2, 1, 256))), scatter={ 'color': sns.color_palette('dark')[4], 'edgecolor': 'black', 'alpha': 0.2, 's': 30, }, highlight_scatter={ 'color': sns.color_palette('bright')[1], 'edgecolor': 'black', 's': 200, 'zorder': 10 } ) YellowPopulationDiagram = ColorPopulationDiagram( theme='yellow', cmap=mcolors.ListedColormap(sns.color_palette( "Wistia", as_cmap=True)(np.linspace(0.2, 1, 256))), scatter={ 'color': sns.color_palette('dark')[4], 'edgecolor': 'black', 'alpha': 0.2, 
's': 30, }, highlight_scatter={ 'color': sns.color_palette('bright')[0], 'edgecolor': 'black', 's': 200, 'zorder': 10 } ) def _print_PopulationDiagramTheme(): """ Collects and logs the themes of all instances of ColorPopulationDiagram found in the global scope. This function iterates through all global variables, identifies instances of the ColorPopulationDiagram class, logs their themes using the logger, and collects these themes into a list which is then returned. Returns: list: A list of themes from all ColorPopulationDiagram instances found in the global scope. """ themes = [] for var_value in globals().values(): if isinstance(var_value, ColorPopulationDiagram): logger.info(f"{var_value.theme}") themes.append(var_value.theme) return themes def _get_PopulationDiagramTheme(theme: str): """ Returns the appropriate PopulationDiagram class based on the provided theme. Parameters ---------- theme : str The theme for the population diagram. Valid options are: 'red', 'green', 'blue', 'purple', 'greyscale', 'grayscale', 'grey', 'gray'. Returns ------- ColorPopulationDiagram The corresponding PopulationDiagram class for the given theme. Raises ------ ValueError If the provided theme is not valid. """ match theme: case 'red': return RedPopulationDiagram case 'green': return GreenPopulationDiagram case 'blue': return BluePopulationDiagram case 'purple': return PurplePopulationDiagram case 'greyscale' | 'grayscale' | 'grey' | 'gray': return GreyScalePopulationDiagram case 'yellow': return YellowPopulationDiagram case _: logger.warning('Invalid theme. Valid options are:') _print_PopulationDiagramTheme() raise ValueError('Not a valid theme')
class PlotUtilitiesComposite:
    """
    A utility class for creating and customizing population diagrams and
    highlighting samples on plots.

    The methods read two attributes that this class does not define and that
    the composing class must provide:

    legacy_table
        Table of the background population, indexed by column key
        (``legacy_table[x_key]``). Presumably a pandas DataFrame -- confirm
        against the composing class.
    table
        Table of the sample to highlight, indexed the same way.

    Methods
    -------
    _set_labels(ax, x_key, y_key)
        Sets the x and y labels for the given axes from LABEL_MAPPER.
    _set_scales(ax, x_key, y_key)
        Sets the x and y scales for the given axes from SCALE_MAPPER.
    plot_population_diagram(x_key, y_key, ax=None, fig=None, theme='blue',
                            sns_context='talk', **kwargs) -> (Figure, Axes)
        Plots a population diagram (KDE background plus scatter) on the given
        axes, or on new ones if none are provided.
    highlight_sample(x_key, y_key, ax=None, fig=None, populate=False,
                     theme='blue', sns_context='talk', **kwargs) -> (Figure, Axes)
        Highlights the sample on a population diagram, drawing the population
        first when no axes are provided (or when ``populate`` is True).
    available_themes()
        Logs and returns the available themes for the population diagram.
    """

    def _set_labels(self, ax: plt.Axes, x_key: str, y_key: str):
        """
        Sets the x and y labels for the given axes based on the provided keys.

        Parameters
        ----------
        ax : plt.Axes
            The matplotlib Axes object to set the labels on.
        x_key : str
            The key to retrieve the x-axis label from LABEL_MAPPER.
        y_key : str
            The key to retrieve the y-axis label from LABEL_MAPPER.
        """
        ax.set_xlabel(LABEL_MAPPER[x_key])
        ax.set_ylabel(LABEL_MAPPER[y_key])

    def _set_scales(self, ax: plt.Axes, x_key: str, y_key: str):
        """
        Sets the x and y scales for the given axes based on the provided keys.

        Parameters
        ----------
        ax : plt.Axes
            The matplotlib Axes object to set the scales on.
        x_key : str
            The key to retrieve the x-axis scale from SCALE_MAPPER.
        y_key : str
            The key to retrieve the y-axis scale from SCALE_MAPPER.
        """
        ax.set_xscale(SCALE_MAPPER[x_key])
        ax.set_yscale(SCALE_MAPPER[y_key])

    @staticmethod
    def _resolve_theme(theme):
        """Return ``theme`` as a ColorPopulationDiagram, looking up string names."""
        if isinstance(theme, str):
            return _get_PopulationDiagramTheme(theme)
        return theme

    def _plottable_mask(self, x_key: str, y_key: str) -> pd.Series:
        """
        Boolean mask over ``legacy_table`` rows that can be drawn on the axes.

        A row is kept when both values are present (not NaN). On an axis whose
        SCALE_MAPPER scale is 'log' the value must also be strictly positive,
        because zero and negative values cannot be placed on a log axis.
        """
        table = self.legacy_table
        mask = table[x_key].notna() & table[y_key].notna()
        for key in (x_key, y_key):
            if SCALE_MAPPER[key] == 'log':
                mask &= table[key] > 0
        return mask

    def plot_population_diagram(self,
                                x_key: str,
                                y_key: str,
                                ax: plt.Axes | None = None,
                                fig: plt.Figure | None = None,
                                theme: str | ColorPopulationDiagram = 'blue',
                                sns_context: str = 'talk',
                                **kwargs
                                ) -> tuple[plt.Figure, plt.Axes]:
        """
        Plots a population diagram on the given axes or creates new ones if not provided.

        Draws a filled KDE of the ``legacy_table`` population, coloured with
        the theme's colormap, and overlays the individual points as a scatter
        plot. Axis scales and labels come from SCALE_MAPPER and LABEL_MAPPER.

        Parameters
        ----------
        x_key : str
            The key to retrieve the x-axis data from the legacy_table.
        y_key : str
            The key to retrieve the y-axis data from the legacy_table.
        ax : plt.Axes | None, optional
            The matplotlib Axes object to plot on. If None, the current axes
            of ``fig`` are used, or a new figure and axes are created when
            ``fig`` is also None.
        fig : plt.Figure | None, optional
            The matplotlib Figure object to plot on. If None and ``ax`` is
            given, ``ax.figure`` is returned; if both are None, a new 12x8
            figure is created.
        theme : str | ColorPopulationDiagram, optional
            The theme for the population diagram. Can be a string representing
            the theme name or an instance of ColorPopulationDiagram. Default is 'blue'.
        sns_context : str, optional
            Seaborn plotting context used while drawing. Default is 'talk'.
        **kwargs
            Additional keyword arguments for the scatter plot; they override
            the theme's ``scatter`` settings.

        Returns
        -------
        tuple[plt.Figure, plt.Axes]
            The matplotlib Figure and Axes objects with the population diagram.

        Raises
        ------
        ValueError
            If ``theme`` is a string that is not a known theme name.

        Notes
        -----
        Rows with NaN in either column are skipped; on log-scaled axes,
        non-positive values are skipped as well.
        """
        # Resolve first so an invalid theme name fails before a figure is created.
        Theme = self._resolve_theme(theme)
        # Theme defaults, then caller overrides. The leading-underscore label is
        # matplotlib's convention for "leave out of the legend"; a caller-supplied
        # label now replaces it instead of raising a duplicate-keyword TypeError.
        scatter_kwargs = {'label': '__nolegend__', **Theme.scatter, **kwargs}

        with sns.plotting_context(sns_context):
            if ax is None:
                if fig is None:
                    fig, ax = plt.subplots(1, figsize=(12, 8))
                else:
                    # Draw on the caller's figure rather than opening a new one.
                    ax = fig.gca()
            elif fig is None:
                # Return the figure that owns ``ax`` instead of None.
                fig = ax.figure

            self._set_scales(ax, x_key, y_key)
            self._set_labels(ax, x_key, y_key)

            # Select with the boolean mask by label (.loc); converting it to
            # index labels and feeding those to positional .iloc picks the wrong
            # rows whenever legacy_table does not have a default 0..n-1 index.
            mask = self._plottable_mask(x_key, y_key)
            x = self.legacy_table.loc[mask, x_key]
            y = self.legacy_table.loc[mask, y_key]

            sns.kdeplot(
                x=x,
                y=y,
                fill=True,
                thresh=0,
                levels=50,
                cmap=Theme.cmap,
                ax=ax,
                # Keep the KDE transform consistent with the scales applied by
                # _set_scales instead of forcing log on both axes.
                log_scale=(SCALE_MAPPER[x_key] == 'log',
                           SCALE_MAPPER[y_key] == 'log'),
            )
            ax.scatter(x=x, y=y, **scatter_kwargs)
        return fig, ax

    def highlight_sample(self,
                         x_key: str,
                         y_key: str,
                         ax: plt.Axes | None = None,
                         fig: plt.Figure | None = None,
                         populate: bool = False,
                         theme: str | ColorPopulationDiagram = 'blue',
                         sns_context: str = 'talk',
                         **kwargs
                         ) -> tuple[plt.Figure, plt.Axes]:
        """
        Highlights a sample on the population diagram.

        Scatters the ``table`` data on top of an existing population diagram.
        When no axes are provided (or ``populate`` is True) the population
        diagram is drawn first with the same theme.

        Parameters
        ----------
        x_key : str
            The key to retrieve the x-axis data from the table.
        y_key : str
            The key to retrieve the y-axis data from the table.
        ax : plt.Axes | None, optional
            The matplotlib Axes object to plot on. If None, a new population
            diagram is created (on ``fig`` if one is given).
        fig : plt.Figure | None, optional
            The matplotlib Figure object to plot on. If None and ``ax`` is
            given, ``ax.figure`` is returned.
        populate : bool
            Whether to draw the population diagram on ``ax`` even when ``ax``
            is provided, by default False.
        theme : str | ColorPopulationDiagram, optional
            The theme for highlighting the sample. Can be a string representing
            the theme name or an instance of ColorPopulationDiagram. Default is 'blue'.
        sns_context : str, optional
            Seaborn plotting context used while drawing. Default is 'talk'.
        **kwargs
            Additional keyword arguments for the scatter plot; they override
            the theme's ``highlight_scatter`` settings.

        Returns
        -------
        tuple[plt.Figure, plt.Axes]
            The matplotlib Figure and Axes objects with the highlighted sample.

        Raises
        ------
        ValueError
            If ``theme`` is a string that is not a known theme name.
        """
        Theme = self._resolve_theme(theme)
        highlight_kwargs = {**Theme.highlight_scatter, **kwargs}

        with sns.plotting_context(sns_context):
            # Decide on ``ax`` (as documented): a caller-supplied axes is only
            # populated on request, and a missing axes always gets a diagram.
            if ax is None or populate:
                fig, ax = self.plot_population_diagram(
                    x_key=x_key,
                    y_key=y_key,
                    ax=ax,
                    fig=fig,
                    theme=Theme,
                    sns_context=sns_context,
                )
            elif fig is None:
                fig = ax.figure

            ax.scatter(
                x=self.table[x_key],
                y=self.table[y_key],
                **highlight_kwargs,
            )
        return fig, ax

    def available_themes(self):
        """
        Log and return the names of the available population-diagram themes.

        Writes a banner via ``logger.print`` (a method provided by the project
        logger set up with logger_default), then logs each theme name.

        Returns
        -------
        list of str
            The theme names, as returned by ``_print_PopulationDiagramTheme``.
        """
        logger.print('=' * 25)
        logger.print('Printing themes for the plot_population_diagram() method')
        themes = _print_PopulationDiagramTheme()
        logger.print('=' * 25)
        return themes