Source code for pytmosph3r.plot.comparison

import os
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning) # for x.p_id in ids issue
from typing import Literal, Optional

import matplotlib as mpl
import numpy as np
from scipy.interpolate import interp1d
from matplotlib import pyplot as plt, ticker

from .modelplot import CurvePlot, Plot
from .plotutils import *


class Comparison(Plot):
    """Compare (and plot) multiple models."""

    def __init__(self, models=None, title=None, suffix=None, cmap='Paired',
                 out_folder='.', interactive=None):
        """Select models to compare with :attr:`models`.

        Args:
            models (iterable, optional): Model plots to compare. ``None``
                entries are dropped. Defaults to None (no models).
            title (str, optional): Stored as :attr:`title`.
            suffix (str, optional): Stored as :attr:`suffix`.
                Defaults to "comp_output".
            cmap (str, optional): Name of the matplotlib colormap stored as
                :attr:`cmap`. Defaults to 'Paired'.
            out_folder (str, optional): Output folder, created if missing.
            interactive (bool, optional): Stored as :attr:`interactive` only
                when it equals True or False.
        """
        BasePlot.__init__(self, self.__class__.__name__)

        # The default `models=None` must mean "nothing to compare", not a
        # TypeError when iterating.
        given = [] if models is None else list(models)
        self.models = np.array([m for m in given if m is not None])
        if len(self.models) <= 1:
            self.warning("No models to compare ("+str(len(given))+")")

        self.title = title
        if interactive in (False, True):
            self.interactive = interactive

        # `matplotlib.cm.get_cmap` was removed in Matplotlib 3.9;
        # `pyplot.get_cmap` is available in old and new releases alike.
        self.cmap = plt.get_cmap(cmap)

        self.suffix = "comp_output" if suffix is None else suffix

        self.out_folder = out_folder
        # exist_ok avoids the check-then-create race of os.path.exists().
        os.makedirs(self.out_folder, exist_ok=True)

    @property
    def wns(self):
        """`wns` attribute of the first model, or None if unavailable."""
        try:
            return self.models[0].wns
        except Exception:  # no model, or the model cannot provide it
            return None

    @property
    def wls(self):
        """`wls` attribute of the first model, or None if unavailable."""
        try:
            return self.models[0].wls
        except Exception:  # no model, or the model cannot provide it
            return None

    def _select_models(self, ids=None):
        """Return the models whose `p_id` is in `ids` (all models if None)."""
        if ids is None:
            return list(self.models)
        return [m for m in self.models if m.p_id in ids]

    def _model_with_id(self, p_id):
        """Return the first model whose `p_id` equals `p_id`, else None."""
        return next((m for m in self.models if p_id == m.p_id), None)

    def _pairs_against(self, ref):
        """Return `[p_id, reference p_id]` for every model except `models[ref]`."""
        n_models = len(self.models)
        if not n_models:
            return []
        ref_index = ref % n_models
        ref_id = self.models[ref].p_id
        return [[m.p_id, ref_id]
                for i, m in enumerate(self.models) if i != ref_index]

    def transmittance_map(self, ids=None, *args, **kwargs):
        """Same parameters as :func:`.plot.Plot.transmittance_map`.

        Args:
            ids (list, optional) : List of ids (`label` parameter) of models to plot. For example ['1D', '2D', '3D'] if you created :code:`Plot(model, label='1D')`, etc.
        """
        for model in self._select_models(ids):
            model.transmittance_map(*args, **kwargs)

    def plot_spectra(self, mode=None, title="Spectra", ids=None, ax=None, figsize=(4,4), legend=True, ref=0, savename=None, func="plot_spectrum", *args, **kwargs):
        """Plot spectra (select `mode` for emission/transmission, etc).

        Args:
            mode (str, optional): Among [transmission, emission, lightcurve, phasecurve]. Defaults to None (in which case the default spectrum is chosen).
            ids (list, optional) : List of ids (`label` parameter) of models to plot. For example ['1D', '2D', '3D'] if you created :code:`Plot(model, label='1D')`, etc.
            ref (int, optional): Index of the reference curve (plot as dashed), taken modulo the number of models. None disables the dashed reference. Defaults to 0.
            savename (str, optional): Name of output file. Defaults to title (lower-cased, with `mode` appended when given).
            func (str, optional): Plot function to call for each model. Defaults to :func:`~pytmosph3r.plot.modelplot.ModelPlot.plot_spectrum`.

        Returns:
            The first truthy value returned by a model's plot function
            (treated as an error; plotting stops there), otherwise None.
        """
        fig, ax, save = self.figure(ax, figsize)

        n_models = len(self.models)
        ref_index = ref % n_models if (ref is not None and n_models) else None

        for i, model in enumerate(self._select_models(ids)):
            # Decide the dash pattern per curve: only the reference is
            # dashed (previously the pattern leaked to all later curves).
            dashes = [5, 2] if i == ref_index else []
            # call plot_spectrum() by default
            err = getattr(model, func)(*args, mode=mode, ax=ax, dashes=dashes,
                                       color=None, label=model.label,
                                       title=title, **kwargs)
            if err:
                return err  # error somewhere, don't plot

        if legend:  # in case you don't need it
            ax.legend()

        if save:
            plt.tight_layout()
            if savename is None:
                savename = title.lower()
                if mode is not None:
                    savename = f"{savename}_{mode}"
            self.save_plot(savename)

    def plot_curves(self, mode="lightcurve", title="Lightcurves", wl=None, wn=None, *args, **kwargs):
        """Same parameters as :func:`~pytmosph3r.plot.modelplot.CurvePlot.plot_curve`. You can select IDs the same way as for :func:`plot_spectra`."""
        return self.plot_spectra(*args, mode=mode, title=title, wl=wl, wn=wn,
                                 savename=title.lower(), func="plot_curve",
                                 **kwargs)

    def plot_lightcurves(self, *args, **kwargs):
        """Same arguments as :func:`plot_curves`."""
        return self.plot_curves(*args, mode="lightcurve", title="Lightcurves", **kwargs)

    def plot_phasecurves(self, *args, **kwargs):
        """Same arguments as :func:`plot_curves`."""
        return self.plot_curves(*args, mode="phasecurve", title="Phasecurves", **kwargs)

    def diff_fluxes(self, mode=None, title="Spectra", ids=None, ax=None, time=None, phase=None, wl=None, wn=None, xlog=True, ylog=False, ppm=None, abs=False, resolution=None, figsize=(4,4), ylabel="Residual", x_axis="wls", x_units=None, ref=-1, ref_phase=None, savename=None, *args, **kwargs):
        """Compare fluxes together.

        Args:
            mode (str, optional): Among [transmission, emission, lightcurve, phasecurve]. Defaults to None (in which case the default spectrum is chosen).
            ids (list, optional) : List of differences of ids (`label` parameter). Example of use: :code:`comparison.diff_fluxes(ids=[["3D", "1D"],["2D", "1D"]])` if you created :code:`Plot(model, label='1D')`, etc..
            ref (int, optional): Index of model to take as a reference. Defaults to -1.
            ref_phase (float or list, optional) : if you want to compare the flux at `phase` against `ref_phase`. A single value is used for every entry of `phase`.
            ppm (bool) : Y units in ppm or not. When None, ppm is enabled for modes None, "transmission" and "lightcurve".
            x_axis (str) : Choose X axis as "wls" or "wns", for wavelengths or wavenumbers, respectively.
            abs (bool, optional): Plot absolute difference. Defaults to False.
            resolution (optional): If set, both spectra are passed through :func:`bin_down` with this value before being compared.

        Returns:
            int: 1 if nothing could be plotted, 0 otherwise.
        """
        fig, ax, save = self.figure(ax, figsize)
        label = None
        suffix = ""
        p_label = ""

        if ppm is None and mode in (None, "transmission", "lightcurve"):
            ppm = True
        # The unit shown in the labels follows the effective `ppm` value,
        # including when the caller passed ppm=True explicitly.
        unit = " ppm" if ppm else ""

        # INFO: Set default value to x_label, to prevent it to be undefined
        xlabel = ''

        # we assume that all curves have the same spectral dimension here
        _, wl, wn, w_units = self.models[ref].init_spectral(wl, wn)
        phase, time, _, time_units = self.models[ref].init_time(phase, time)

        if ids is None:
            ids = self._pairs_against(ref)

        # Reference phases default to the compared phases themselves.
        ref_phases = np.atleast_1d(phase if ref_phase is None else ref_phase)

        for comp in ids:
            model0 = self._model_with_id(comp[0])
            model1 = self._model_with_id(comp[1])
            if model0 is None or model1 is None:
                continue

            # either make phase OR wn vary, not both
            for j, ph in enumerate(np.atleast_1d(phase)):
                for w in np.atleast_1d(wn):
                    try:
                        spectrum0 = model0.flux(mode, phase=ph, wn=w)
                        # A single reference phase applies to every phase.
                        ph_ref = ref_phases[j] if len(ref_phases) > 1 else ref_phases[0]
                        spectrum1 = model1.flux(mode, phase=ph_ref, wn=w)
                    except Exception:
                        # best effort: skip combinations that cannot be computed
                        continue
                    if spectrum0 is None or spectrum1 is None:
                        continue

                    if resolution:
                        spectrum0 = self.bin_down(resolution, spectrum0)
                        spectrum1 = self.bin_down(resolution, spectrum1)

                    # get x
                    if x_axis in ("phases", "times"):
                        x_0, xlabel = model0.x_axis_curve(x_axis=x_axis, x_units=x_units, mode=mode)
                        x_1, xlabel = model1.x_axis_curve(x_axis=x_axis, x_units=x_units, mode=mode)
                    else:
                        x_0 = getattr(spectrum0, x_axis)
                        x_1 = getattr(spectrum1, x_axis)
                        xlabel = label_from_dim(x_axis)

                    # get y
                    y_1 = spectrum1.value
                    # if models don't have the same spectral ranges
                    y_0 = interp1d(x_0, spectrum0.value, assume_sorted=False)(x_1)

                    diff = (y_0 - y_1)
                    if ppm:
                        diff *= 1e6
                    if abs:
                        diff = np.abs(diff)
                    err = np.mean(np.abs(diff))

                    # labelling
                    if np.atleast_1d(time)[0] is not None:
                        t, units = time_fmt(time[model1.ph_index], time_units)
                        suffix = f" @ {t:.2f} {units}"
                    elif ph is not None:
                        suffix = f" @ {np.degrees(model1.getattr(mode, 'phases')[model1.ph_index]):.1f}°"
                    if w is not None:
                        suffix = f" @ {get_spectral(10000/w, w, w_units):.1f} {units_from_dim(w_units)}"
                    if (np.ndim(phase) and len(phase)>1) or (np.ndim(wn) and len(wn)>1):
                        p_label = suffix
                    label = f"{model0.label} - {model1.label}{p_label} = {err:.3g}{unit}"

                    ax.plot(x_1, diff, *args, label=label, **kwargs)

        if ylog:
            ax.set_yscale('log')
        if xlog:
            ax.set_xscale('log')
            ax.xaxis.set_major_formatter(ticker.FormatStrFormatter('%.1f'))
        ax.set_xlabel(xlabel)
        if ppm:
            ax.set_ylabel(f"{ylabel} [ppm]")
        else:
            ax.set_ylabel(f"{ylabel}")

        if label is None:
            return 1  # no plot

        if save:
            ax.legend()
            if (np.ndim(phase) and len(phase)<2) or (np.ndim(wn) and len(wn)<2):
                # `suffix` already starts with " @ ": do not add another one.
                if title and suffix:
                    p_title = f"{title}{suffix}"
                else:
                    p_title = f"{suffix}"
                ax.set_title(f"{p_title}")
            plt.tight_layout()
            if savename is None:
                savename = title.lower()
            self.save_plot(f"diff_{savename}")
        return 0

    def diff_spectra(self, phase=None, *args, **kwargs):
        """
        Compare models using ids. Example of use:
        :code:`comparison.diff_spectra(ids=[["3D", "1D"],["2D", "1D"]])` if you created :code:`Plot(model, label='1D')`, etc..
        See :func:`diff_fluxes` for more information.
        """
        return self.diff_fluxes(*args, phase=phase, **kwargs)

    def diff_curves(self, x_axis: Literal["times","phases"] = "times", xlog=False, *args, **kwargs):
        """
        Compare phase/light-curves together.
        See :func:`diff_fluxes` for more information.
        """
        return self.diff_fluxes(*args, xlog=xlog, x_axis=x_axis, **kwargs)

    def diff_lightcurves(self, wl=None, wn=None, *args, **kwargs):
        """See :func:`diff_fluxes` for more information."""
        return self.diff_curves(*args, mode="lightcurve", wl=wl, wn=wn, **kwargs)

    def diff_phasecurves(self, wl=None, wn=None, *args, **kwargs):
        """See :func:`diff_fluxes` for more information."""
        return self.diff_curves(*args, mode="phasecurve", wl=wl, wn=wn, **kwargs)

    def plot_tp(self, ax=None, title="PT profile", logx=False, logy=True, figsize=(9,3)):
        """TP profile of one column.

        Plots `temperature` against `pressure` of every model with
        :func:`plot_column`; `logx`/`logy` switch the axes to log scale.
        """
        # Honour the caller's figsize (it used to be hard-coded to (9,3)).
        fig, ax, save = self.figure(ax, figsize=figsize)

        for model in self.models:
            self.plot_column(ax, model.temperature, model.pressure, label=model.label)

        if logy:
            ax.set_yscale('log')
        if logx:
            ax.set_xscale('log')

        if save:
            ax.set_title(title)
            self.tp_legend(ax, fig=ax.get_figure())
            self.save_plot("tp")

    def plot_zp(self, ax=None, title="ZP profile", logx=True, logy=False, figsize=(9,3)):
        """ZP profile of one column.

        Plots `pressure` against `z` of every model with
        :func:`plot_column`; `logx`/`logy` switch the axes to log scale.
        """
        fig, ax, save = self.figure(ax, figsize=figsize)

        for model in self.models:
            self.plot_column(ax, model.pressure, model.z, label=model.label)

        if logy:
            ax.set_yscale('log')
        if logx:
            ax.set_xscale('log')

        if save:
            ax.set_title(title)
            self.zp_legend(ax, fig)
            self.save_plot("zp")

    def plot_xprofile(self, ax=None, figsize=(9,3), *args, **kwargs):
        """Mixing ratio. longitude = 1 plots the terminator. Outdated?

        Returns:
            tuple: (list of one `Line2D` legend handle per model,
            dict of gas legends merged from every model).
        """
        fig, ax, save = self.figure(ax, figsize=figsize)
        num_models = len(self.models)
        xmin = 1
        gas_legends = {}
        models = []

        for model_idx, model in enumerate(self.models):
            # one distinct dash pattern per model
            dashes = [(num_models-model_idx)+2, (model_idx*3)+3]
            model_gas_legend, model_min, model_max = model.plot_xprofile(ax, *args, dashes=dashes, **kwargs)
            try:
                xmin = min(model_min, xmin)
            except (TypeError, ValueError):
                # the model returned no usable minimum: keep the current one
                pass
            models.append(mpl.lines.Line2D([0], [0], dashes=dashes, label=model.label))
            gas_legends.update(model_gas_legend)

        ax.set_xlim(max(1e-12, xmin))
        if save:
            self.save_plot("vmr")
        return models, gas_legends

    def x_legend(self, axes, fig, legends):
        """Place legend with model + gas labels.

        Args:
            axes: A single axes or an array of axes.
            fig: Figure receiving the legends.
            legends (tuple): (model handles, dict of gas handles), as
                returned by :func:`plot_xprofile`.
        """
        if isinstance(axes, np.ndarray):
            self.legend2D(axes)
            ax = axes.flatten()[0]
        else:
            ax = axes
        ax.invert_yaxis()
        plt.xlabel('Mixing ratio')
        plt.ylabel('Pressure (Pa)')
        plt.tight_layout()
        fig.subplots_adjust(right=0.9, wspace=0.25, hspace=0.35)
        plt.gca().add_artist(fig.legend(handles=legends[0], loc=1))
        plt.gca().add_artist(fig.legend(handles=list(legends[1].values()), loc=4))

    def tp_legend(self, axes, fig, *args, **kwargs):
        """Call :func:`Plot.tp_legend`, then add the model legend with :func:`comp_legend`."""
        Plot.tp_legend(self, axes, *args, **kwargs)
        self.comp_legend(axes, fig, *args, **kwargs)

    def zp_legend(self, axes, fig, *args, **kwargs):
        """Call :func:`Plot.zp_legend`, then add the model legend with :func:`comp_legend`."""
        Plot.zp_legend(self, axes, *args, **kwargs)
        self.comp_legend(axes, fig, *args, **kwargs)

    def comp_legend(self, axes, fig, *args, **kwargs):
        """Place legend with model labels."""
        if isinstance(axes, np.ndarray):
            self.legend2D(axes)
            ax1 = axes.flatten()[-1]
        else:
            ax1 = axes
        plt.tight_layout()
        h, labels = ax1.get_legend_handles_labels()
        fig.legend(h, labels, loc=4)

    def legend2D(self, axes):
        """Legend for rows and columns (latitudes and longitudes) when using :func:`Plot.plot_columns`."""
        if not hasattr(axes, "__len__") or len(axes.flatten()) == 1:
            return
        # row headers: one latitude per first-column axes
        for ax, lat in zip(axes[:,0], self.latitudes):
            ax.annotate("Latitude:\n %s"% lat, xy=(0, 0.5), xytext=(-ax.yaxis.labelpad - 10, 0),
                        xycoords=ax.yaxis.label, textcoords='offset points',
                        size='large', ha='right', va='center')
        # column headers: one longitude per first-row axes
        for ax, lon in zip(axes[0], self.longitudes):
            ax.annotate("Longitude:\n %s"% lon, xy=(0.5, 1), xytext=(0, 5),
                        xycoords='axes fraction', textcoords='offset points',
                        size='large', ha='center', va='baseline')

    def plot_diff_spectra(self, mode=None, plots=None, compares=None, suffix=None, abs=False, figsize=(9,5), *args, **kwargs):
        """Plot spectra and their differences. See parameters of :func:`plot_spectra` and :func:`diff_spectra` for more information.

        Args:
            mode (str, optional): Among [transmission, emission, lightcurve, phasecurve]. Defaults to None (in which case the default spectrum is chosen).
            plots (list, optional) : List of ids (`label` parameter) of models to plot. see parameter `ids` of :func:`plot_spectra`.
            compares (list, optional) : List (`label` parameter) of differences of ids to plot. see parameter `ids` of :func:`diff_fluxes`.
            phase (ndarray, optional): List of phases to plot (in curve modes only) in degrees.

        Returns:
            1 if either panel could not be plotted, otherwise None.
        """
        fig, ax = plt.subplots(nrows=2, ncols=1, figsize=figsize, gridspec_kw={'height_ratios': [3, 1]}, sharex=True)

        err = self.plot_spectra(*args, mode=mode, ids=plots, ax=ax[0], xlabel='', **kwargs)
        if err:
            return 1
        err = self.diff_spectra(*args, mode=mode, ids=compares, ax=ax[1], abs=abs, **kwargs)
        if err:
            return 1

        # legends go to the right of the panels; add columns for many curves
        ax[0].legend(loc="upper left", bbox_to_anchor=(1, 1.04), ncol=int(np.ceil(len(ax[0].lines)/7)))
        ax[1].legend(loc="lower left", bbox_to_anchor=(1, -0.06), ncol=int(np.ceil(len(ax[1].lines)/11)))

        if suffix is None:
            suffix = f"{mode}_{self.suffix}"
        self.save_plot("spectra_diff", suffix=suffix)

    def plot_diff_curves(self, mode="lightcurve", plots=None, compares=None, suffix=None, abs=False, figsize=(9,5), *args, **kwargs):
        """Plot light/phase-curves and their differences. See parameters of :func:`plot_curves` and :func:`diff_curves` for more information.

        Args:
            plots (list, optional) : List of ids (`label` parameter) of models to plot. see parameter `ids` of :func:`plot_spectra`.
            compares (list, optional) : List (`label` parameter) of differences of ids to plot. see parameter `ids` of :func:`diff_fluxes`.
            wl/wn (float, optional): Wavelength/wavenumber of the curve (can be a list).

        Returns:
            1 if either panel could not be plotted, otherwise None.
        """
        fig, ax = plt.subplots(nrows=2, ncols=1, figsize=figsize, gridspec_kw={'height_ratios': [3, 1]}, sharex=True)

        err = self.plot_curves(*args, mode=mode, ids=plots, ax=ax[0], xlabel='', **kwargs)
        if err:
            return 1
        err = self.diff_curves(*args, mode=mode, ids=compares, ax=ax[1], abs=abs, **kwargs)
        if err:
            return 1

        # legends go to the right of the panels; add columns for many curves
        ax[0].legend(loc="upper left", bbox_to_anchor=(1, 1.04), ncol=int(np.ceil(len(ax[0].lines)/7)))
        ax[1].legend(loc="lower left", bbox_to_anchor=(1, -0.06), ncol=int(np.ceil(len(ax[1].lines)/11)))

        self.save_plot(f"{mode}s_diff", suffix=suffix)

    def plot_diff_lightcurves(self, *args, **kwargs):
        """See :func:`plot_diff_curves` for more information."""
        return self.plot_diff_curves(*args, mode="lightcurve", **kwargs)

    def plot_diff_phasecurves(self, *args, **kwargs):
        """See :func:`plot_diff_curves` for more information."""
        return self.plot_diff_curves(*args, mode="phasecurve", **kwargs)

    def plot_2d_fluxes_residuals(self, mode="lightcurve", ax=None, title="Lightcurve residuals", ids=None, ref=0, ppm=True, x_axis="wls", figsize=(5,3.5), savename="residuals_2d", colorbar_kwargs=None, **kwargs):
        """Plot curve residuals (select `mode` for light/phase-curve). See :func:`~pytmosph3r.plot.modelplot.CurvePlot.plot_curves` for more parameters.

        Args:
            mode (str, optional): Among [lightcurve, phasecurve]. Defaults to lightcurve.
            ids (list, optional) : List of differences of ids (`label` parameter), e.g. [["3D", "1D"],["2D", "1D"]]. Defaults to every model against the reference one.
            ref (int, optional): Index of the reference model. Defaults to 0.
            ppm (bool, optional): Scale the residuals by 1e6 and label the colorbar 'ppm'. Defaults to True.
            savename (str, optional): Prefix for output filenames. Defaults to 'residuals_2d'.
            colorbar_kwargs (dict, optional): Colorbar options. The dict is copied, never modified; 'label' is always overridden and 'format' only filled in when absent.
        """
        # Work on a private copy: neither the caller's dict nor a shared
        # default may carry 'format'/'label' over to another call.
        cbar_opts = dict(colorbar_kwargs) if colorbar_kwargs else {}
        cbar_opts['label'] = 'ppm' if ppm else None
        cbar_opts.setdefault('format', '%.1f' if ppm else '%.3f')

        if ids is None:
            ids = self._pairs_against(ref)

        for i, comp in enumerate(ids):
            model0 = self._model_with_id(comp[0])
            model1 = self._model_with_id(comp[1])
            if model0 is None or model1 is None:
                continue

            flux = model0.curve(mode) - model1.curve(mode)
            if ppm:
                flux *= 1e6

            x = model0.get_spectral(w_units=x_axis, mode=mode)[0]
            y = np.degrees(model0.mode(mode).phases)
            CurvePlot.plot_2d_flux(self, mode=mode, ax=ax, title=title, figsize=figsize,
                                   x=x, y=y, flux=flux, x_axis=x_axis,
                                   savename=f"{savename}_{i}",
                                   colorbar_kwargs=cbar_opts, **kwargs)

    def plot_2d_lightcurves_residuals(self, *args, **kwargs):
        """See :func:`plot_2d_fluxes_residuals`."""
        return self.plot_2d_fluxes_residuals(*args, **kwargs)

    def plot_2d_phasecurves_residuals(self, title='Phasecurve residuals', mode='phasecurve', *args, **kwargs):
        """See :func:`plot_2d_fluxes_residuals`."""
        return self.plot_2d_fluxes_residuals(*args, mode=mode, title=title, **kwargs)