from __future__ import annotations

import os
import warnings
from typing import TYPE_CHECKING, Literal

import matplotlib as mpl
import matplotlib.pylab as pl
import matplotlib.pyplot as plt
import numpy as np
from ai import cs
from matplotlib import ticker
from matplotlib.transforms import Affine2D

from .plotutils import *
from ..interface.hdf5 import HDF5Input
from ..util.util import *

if TYPE_CHECKING:
    from ..model import Model
class ModelPlot(BasePlot):
    """Intermediary class that can be inherited from by classes from pytmosph3r, to make plots from them directly.

    The methods below read model attributes and helpers that are not defined in this class (e.g. ``grid``, ``r``,
    ``R``, ``pressure``, ``temperature``, ``gas_mix_ratio``, ``aerosols``, ``rays()``, ``mode()``, ``save_plot()``);
    they must be provided by :class:`BasePlot` or by the inheriting class.
    """

    def plot_rays(self, points=True, mid_points=False, rays=False, rays_bottom=False, rays_top=True,
                  rays_terminator=True, figsize=None, mode="transmission"):
        """Plot rays with matplotlib (3D axis).

        Args:
            points (bool, optional): Plot the points of the rays. Defaults to True.
            mid_points (bool, optional): Plot the mid-points of the rays. Defaults to False.
            rays (bool, optional): Plot the rays themselves. Not implemented yet. Defaults to False.
            rays_bottom (bool, optional): Display the bottom layer (surface) of the planet. Defaults to False.
            rays_top (bool, optional): Display the top layer of the planet. Defaults to True.
            rays_terminator (bool, optional): Display the terminator plane. Defaults to True.
            figsize (tuple, optional): Size of the figure. Defaults to None (matplotlib default).
            mode (str, optional): Mode whose rays are plotted. Defaults to "transmission".

        Raises:
            NotImplementedError: If ``rays`` is True.
        """
        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(111, projection='3d')
        lat, lon = np.meshgrid(self.grid.mid_latitudes, self.grid.all_longitudes)
        if rays_bottom:
            points_b = cs.sp2cart(self.R, lat, lon)
            s = ax.plot_surface(points_b[0], points_b[1], points_b[2], label="surface", alpha=.3)
            # 3D surfaces lack the 2D color attributes read by plt.legend()
            s._facecolors2d = s._facecolors3d
            s._edgecolors2d = s._edgecolors3d
        if rays_top:
            points_t = cs.sp2cart(self.r.max(), lat, lon)
            ax.plot_wireframe(points_t[0], points_t[1], points_t[2], label="top", alpha=.4)
        if rays_terminator:
            self._plot_terminator(ax, mode)
        try:
            # mode is required by plot_points() when the points are stored as an ndarray
            if points:
                self.plot_points(ax, self.rays(mode).points, mode=mode)
            if mid_points:
                self.plot_points(ax, self.rays(mode).mid_points, mode=mode)
        except Exception:
            self.debug("No points in output file. Try running pytmosph3r with -v")
        if rays:
            raise NotImplementedError
        plt.xlabel("x")
        plt.ylabel("y")
        plt.legend()
        self.save_plot("rays")

    def _plot_terminator(self, ax, mode, scale=1, num=2):
        """Draw the plane through the origin whose normal is the rays' direction (terminator plane).

        Args:
            ax: 3D matplotlib axis.
            mode (str): Mode whose rays' direction is used.
            scale (float, optional): Half-extent of the plane, in units of ``self.R``. Defaults to 1.
            num (int, optional): Number of grid points per side of the plane. Defaults to 2.
        """
        direction = self.rays(mode).cartesian_system.direction
        a, b, c = direction.x, direction.y, direction.z
        extent = np.linspace(-self.R * scale, self.R * scale, num)
        # solve a*X + b*Y + c*Z = 0 for a coordinate whose coefficient is non-zero
        if c != 0:
            X, Y = np.meshgrid(extent, extent)
            Z = -(a * X + b * Y) / c
        elif b != 0:
            X, Z = np.meshgrid(extent, extent)
            Y = -(a * X + c * Z) / b
        else:
            Y, Z = np.meshgrid(extent, extent)
            X = -(c * Z + b * Y) / a
        s = ax.plot_surface(X, Y, Z, label="terminator", alpha=.5)
        # 3D surfaces lack the 2D color attributes read by plt.legend()
        s._facecolors2d = s._facecolors3d
        s._edgecolors2d = s._edgecolors3d

    def plot_points(self, ax, points, mode=None):
        """Plot every ray contained in ``points`` (one line per ray, see :func:`plot_points_ray`).

        Args:
            ax: 3D matplotlib axis.
            points (dict, list or ndarray): Rays to plot. A dict is iterated over its values; a list over its
                elements, or over the elements of its sub-lists if it is a list of lists; an ndarray is indexed by
                the ``(radius, angle)`` pairs yielded by ``self.rays(mode).walk()``.
            mode (str, optional): Mode of the rays. Required when ``points`` is an ndarray.

        Raises:
            ValueError: If ``points`` is an ndarray and ``mode`` is None.
            TypeError: If ``points`` is not a dict, list or ndarray.
        """
        if isinstance(points, dict):
            rays = points.values()
        elif isinstance(points, list):
            if points and isinstance(points[0], list):
                for p_angle in points:
                    for p_radius in p_angle:
                        self.plot_points_ray(ax, p_radius)
                return
            rays = points
        elif isinstance(points, np.ndarray):
            if mode is None:
                raise ValueError('`mode` should be set.')
            rays = [points[radius, angle] for radius, angle in self.rays(mode).walk()]
        else:
            raise TypeError(f"Unsupported type for `points`: {type(points).__name__}")
        for ray in rays:
            self.plot_points_ray(ax, ray)

    def plot_points_ray(self, ax, ray):
        """Plot the points of a single ray as a 3D line (nothing is drawn for an empty ray).

        Args:
            ax: 3D matplotlib axis.
            ray (dict or ndarray): Either a dict with "radius", "latitude" and "longitude" entries, or a 2D array
                whose columns 1, 2 and 3 are read as radius, latitude and longitude.
        """
        if len(ray) < 1:  # no points
            return
        if isinstance(ray, dict):
            points = cs.sp2cart(ray["radius"] / self.h_unit + self.R, ray["latitude"], ray["longitude"])
        else:
            # NOTE(review): unlike the dict branch, R is not added here; presumably column 1 already holds the
            # distance to the planet center — confirm.
            points = cs.sp2cart(ray[:, 1] / self.h_unit, ray[:, 2], ray[:, 3])
        ax.plot(points[0], points[1], points[2])

    def plot_2Dmap(self, ax, location, dim, x, y, z, p_levels=None, cmap="YlOrRd",
                   log=False, vmin=None, imshow=False, figsize=(5, 3.5), *args, **kwargs):
        """Plot a 2D map at a specific location and dimension (core function). Called by :func:`map_2D`.

        Args:
            ax: Axis to draw on (ignored when the map is degenerate, see Returns).
            location (str, int): Name ("equator", ...) or index of location to plot
            dim (str, int): Dimension of location (altitude/latitude/longitude)
            x (ndarray): Meshgrid
            y (ndarray): Meshgrid
            z (ndarray): Values to plot
            p_levels (list, optional): Pressure levels to plot over the map. Defaults to [1e-4, 1, 100, 10**4].
            cmap (str, optional): Colormap to be used. Defaults to "YlOrRd".
            log (bool): Log scale for colors. Defaults to False.
            vmin (float): Minimum value for colorbar. Values below it are clipped to it.
            vmax (float): Maximum value for colorbar (passed through ``kwargs``).
            imshow (bool): If True, the map will use plt.imshow() instead of plt.contourf(). imshow() shows exactly
                the temperature map used, while contourf() makes it smoother. Defaults to False (i.e., contourf).
            figsize (tuple): Size of the new figure created for degenerate maps. Defaults to (5, 3.5).

        Returns:
            A new axis when the selected data has a single cell along one dimension (drawn with imshow() on a new
            figure), None otherwise (the map is drawn on ``ax``).
        """
        if p_levels is None:
            p_levels = [1e-4, 1, 100, 10 ** 4]
        hz = get_2D(z, location, dim)
        if isinstance(hz, (float, str)):
            hz = np.full((len(x), len(y)), hz)
        if hz.ndim < 2:
            # promote scalars/1D data to a (1, n) map instead of crashing on shape.index(1)
            hz = np.atleast_2d(hz)
        if 1 in hz.shape:
            # degenerate map: a single cell along one dimension, drawn with imshow() on a new figure
            dim_0D = hz.shape.index(1)
            aspect = .1 if dim_0D else 1.
            fig = plt.figure(figsize=figsize)
            ax = fig.add_subplot(111)
            mappable = ax.imshow(hz, aspect=aspect, cmap=cmap, vmin=vmin, *args, **kwargs)
            plt.colorbar(mappable)
            # NOTE(review): with x, y from np.meshgrid (as in map_2D), y[0] repeats a single vertical value;
            # y[:, 0] may have been intended — confirm.
            ax.set_yticklabels(['%.2f' % i for i in y[0].tolist()])
            if not dim_0D:
                ax.set_xticklabels(['%g' % i for i in x[0].tolist()])
            if dim == "latitude":
                ax.set_xlabel('East Longitude')
                ax.set_ylabel('Altitude (Mm)')
            elif dim == "longitude":
                ax.set_xlabel('Latitude')
                ax.set_ylabel('Altitude (Mm)')
            elif dim == "altitude":
                ax.set_xlabel('East Longitude')
                ax.set_ylabel('Latitude')
            return ax
        if dim != "longitude":
            # repeat the first longitude so that the map is closed
            hz = np.concatenate((hz, hz[:, 0:1]), axis=1)
        if dim != "altitude" and p_levels is not None:
            zp = get_2D(self.pressure, location, dim)
            if dim == "latitude":
                zp = np.concatenate((zp, zp[:, 0:1]), axis=1)
            ax.contour(x, y, zp, colors="black", linewidths=.2,
                       locator=ticker.FixedLocator(p_levels))
        locator = ticker.LinearLocator(100)
        formatter = None
        extend = 'neither'
        if log:
            locator = ticker.LogLocator(base=1.01, subs=(1.0,), numticks=100)
            formatter = ticker.LogFormatter(1.01, labelOnlyBase=False)
        if vmin:
            # clip on a copy: hz may be a view on the caller's data
            hz = np.where(hz < vmin, vmin, hz)
            extend = 'min'
        if imshow:
            mappable = ax.imshow(hz, extent=[x.min(), x.max(), y.min(), y.max()], cmap=cmap, vmin=vmin,
                                 *args, **kwargs)
        else:
            mappable = ax.contourf(x, y, hz, cmap=cmap, vmin=vmin, locator=locator, extend=extend,
                                   *args, **kwargs)
        plt.colorbar(mappable, format=formatter, pad=0.08)
        return None

    def plot_2D(self, func, dim=None, altitudes=None, latitudes=None, longitudes=None, *args, **kwargs):
        """Calls :attr:`func` on all `locations` of :attr:`dim` for a 2D plot. Can also select altitudes, latitudes and longitudes separately.

        A single location (number or name such as "equator") is accepted as well as a list of locations.
        If ``dim`` is not recognized and no location list is given, a warning is issued and nothing is plotted.
        """
        if altitudes is not None:
            loop = altitudes
            dim = "altitude"
        elif latitudes is not None:
            loop = latitudes
            dim = "latitude"
        elif longitudes is not None:
            loop = longitudes
            dim = "longitude"
        elif dim == "altitude":
            loop = self.altitudes
        elif dim == "latitude":
            loop = self.latitudes
        elif dim == "longitude":
            loop = self.longitudes
        else:
            warnings.warn(
                "Dimension '%s' not recognized. Should be among 'altitude', 'latitude' or 'longitude'. Not plotting 2D." % dim)
            loop = []
        # a single location (a name must not be iterated character by character)
        if isinstance(loop, (str, float, int, np.integer, np.floating)):
            loop = [loop]
        for location in loop:
            func(location=location, dim=dim, *args, **kwargs)

    def map_2D(self, array, location="equator", dim="latitude", ax=None, figsize=(5, 3.5), *args, **kwargs):
        """Generic 2D map plot for a specific dimension & location. Selects the data to send to :func:`plot_2Dmap`.

        Altitude maps are drawn on a cartesian (longitude, latitude) axis; latitude/longitude slices are drawn on a
        polar axis whose radius is the altitude (or log10 of the pressure when ``self.r`` is not 1D).

        Args:
            array (ndarray): Values to map.
            location (str, int): Name ("equator", ...) or index of the location. Defaults to "equator".
            dim (str): "altitude", "latitude" or "longitude". Defaults to "latitude".
            ax: Axis used only for latitude/longitude slices with a single cell horizontally; otherwise a new one
                is created.
            figsize (tuple): Size of the figure. Defaults to (5, 3.5).

        Returns:
            The axis holding the map.

        Raises:
            ValueError: If ``dim`` is not "altitude", "latitude" or "longitude".
        """
        if dim not in ("altitude", "latitude", "longitude"):
            raise ValueError(f"Dimension '{dim}' not recognized. Should be among 'altitude', 'latitude' or 'longitude'.")
        fig = plt.figure(figsize=figsize)
        vertical = self.r  # by default, altitude is used as a vertical measure
        if dim == "altitude":
            ax = fig.add_subplot(111)
            longitudes = np.concatenate((self.grid.mid_longitudes, self.grid.mid_longitudes[0:1] + 2 * np.pi))
            x, y = np.degrees(np.meshgrid(longitudes, self.grid.mid_latitudes))
            ax.set_xlabel('East Longitude')
            ax.set_ylabel('Latitude')
        else:
            if dim == "latitude":
                polar = self.grid.n_longitudes > 1
                horizontal = np.concatenate((self.grid.mid_longitudes, self.grid.mid_longitudes[0:1] + 2 * np.pi))
            else:
                polar = self.grid.n_latitudes > 1
                horizontal = self.grid.mid_latitudes
            if polar:  # otherwise plot_2Dmap() does a normal plot
                ax = fig.add_subplot(111, projection='polar')
            if self.r.ndim > 1:
                # in Emission mode, we don't compute the altitude so we use the pressure as a vertical measure,
                # supposed to be 1D
                vertical = np.log10(self.pressure[:, 0, 0])
            x, y = np.meshgrid(horizontal, vertical)
            if self.r.ndim > 1 and ax is not None:
                ax.set_rlim(bottom=vertical.max(), top=vertical.min())
        newax = self.plot_2Dmap(ax=ax, location=location, dim=dim, x=x, y=y, z=array, *args, **kwargs)
        if newax is not None:
            ax = newax
        if dim == "longitude":
            if self.grid.mid_latitudes[0] != self.grid.mid_latitudes[-1]:
                ax.set_xlim(self.grid.mid_latitudes[0], self.grid.mid_latitudes[-1])
        elif dim == "latitude":
            try:
                ax.set_theta_zero_location("W")
                angles = [45 * theta for theta in range(0, 8)]
                ax.set_xticks(np.deg2rad(angles))
                ax.set_xticklabels([f'{theta}°' for theta in angles], fontsize=8)
            except Exception:  # in case of 1x1 maps (non-polar axis)
                plt.tight_layout(pad=2)
        else:
            plt.tight_layout(pad=2)
        if dim != "altitude" and newax is None:
            # polar-only decorations
            ax.grid(linewidth=.1)
            if self.r.ndim == 1:
                ax.set_rmin(0)
                ax.set_yticks([self.r.min(), self.r.max()])
                ax.set_yticklabels(
                    [f'{level:,.1f} Mm' for level in [self.z_levels[0], self.z_levels[-1]]], fontsize=6)
                ax.set_rgrids([self.r.min(), self.r.max()])
            else:
                # radial grid at fixed log10(pressure) levels, plus the extrema, within the plotted range
                p_levels = np.array([-4, 0, 2, 4.])
                p_levels = np.insert(p_levels, np.searchsorted(p_levels, vertical.max()), vertical.max())
                p_levels = np.insert(p_levels, np.searchsorted(p_levels, vertical.min()), vertical.min())
                p_levels = p_levels[np.where((p_levels >= vertical.min()) & (p_levels <= vertical.max()))][::-1]
                ax.set_rgrids(p_levels)
                ax.set_yticklabels(["{:,.1g}".format(p) + ' Pa' for p in np.power(10, p_levels)], fontsize=6)
            ax.set_rlabel_position(80)
            ax.tick_params(pad=0)
        return ax

    def t_map(self, location="equator", dim="latitude", ax=None, cmap="gnuplot2", *args, **kwargs):
        """Temperature 2D map for a specific dimension & location (calls :func:`map_2D` with identical parameters)."""
        ax = self.map_2D(self.temperature, location=location, dim=dim, ax=ax, cmap=cmap, *args, **kwargs)
        index = get_index(self.grid, location, dim)
        dim_display = "pressure" if self.vertical_in_pressure else dim
        ax.set_title("Temperature (K) at %s %s" % (dim_display, self.get_value_dim(index, dim)))
        self.save_plot("t_map_%s_%s" % (dim, index))

    def t_maps(self, dim="latitude", *args, **kwargs):
        """Temperature 2D maps over multiple locations. See :func:`plot_2Dmap` for further parameters. You can select altitudes, latitudes and longitudes using arguments (see :func:`plot_2D`) or by setting them beforehand:

        - :attr:`self.altitudes <altitudes>` when :attr:`dim = "altitude"`,
        - :attr:`self.latitudes <latitudes>` when :attr:`dim = "latitude"` (default),
        - :attr:`self.longitudes <longitudes>` when :attr:`dim = "longitude"`
        """
        self.plot_2D(self.t_map, dim, *args, **kwargs)

    def x_map(self, gas=None, location="equator", dim="latitude", cmap="PuBuGn", ax=None, *args, **kwargs):
        """VMR 2D map for a specific dimension & location (calls :func:`map_2D` with identical parameters).

        A gas whose mix ratio is a string (e.g. "background") is mapped as 1 minus the sum of the other VMRs.
        Logs an error and returns if ``gas`` is not in ``self.gas_mix_ratio``.
        """
        if gas not in self.gas_mix_ratio:
            self.error(f"Gas {gas} not in mix ratio.")
            return
        if isinstance(self.gas_mix_ratio[gas], str):
            total_vmr = sum([x for x in self.gas_mix_ratio.values() if not isinstance(x, str)])
            vmr = 1 - total_vmr * np.ones(self.shape)
        else:
            vmr = self.gas_mix_ratio[gas] * np.ones(self.shape)
        # only clip (log scale) when the data goes down to 1e-16 or below
        vmin = max(np.min(vmr), 1e-16)
        if vmin > 1e-16:
            vmin = None
        ax = self.map_2D(vmr, location=location, dim=dim, ax=ax, cmap=cmap, log=True, vmin=vmin, *args,
                         **kwargs)
        index = get_index(self.grid, location, dim)
        dim_display = "pressure" if self.vertical_in_pressure else dim
        ax.set_title("[%s] at %s %s" % (gas, dim_display, self.get_value_dim(index, dim)))
        os.makedirs(os.path.join(self.out_folder, gas), exist_ok=True)
        self.save_plot(os.path.join(gas, f"{gas}_map_{dim}_{index}"))

    def x_maps(self, gases=None, dim="latitude", *args, **kwargs):
        """Gas Volume Mixing ratio 2D maps over multiple locations. See :func:`plot_2Dmap` for further parameters. You should set beforehand:

        - :attr:`self.altitudes <altitudes>` when :attr:`dim = "altitude"`,
        - :attr:`self.latitudes <latitudes>` when :attr:`dim = "latitude"` (default),
        - :attr:`self.longitudes <longitudes>` when :attr:`dim = "longitude"`

        ``gases`` defaults to all gases of ``self.gas_mix_ratio``; a single gas name is accepted.
        """
        if gases is None:
            gases = self.gas_mix_ratio
        if isinstance(gases, str):
            gases = [gases]
        for gas in gases:
            self.plot_2D(self.x_map, gas=gas, dim=dim, *args, **kwargs)

    def a_map(self, aerosol=None, location="equator", dim="latitude", cmap="BuPu", ax=None, *args, **kwargs):
        """Aerosols MMRs 2D map for a specific dimension & location (calls :func:`map_2D` with identical parameters).

        Logs an error and returns if ``aerosol`` is not in ``self.aerosols``.
        """
        if aerosol not in self.aerosols:
            self.error(f"Aerosol {aerosol} not in aerosols.")
            return
        mmr = self.aerosols[aerosol]["mmr"] * np.ones(self.shape)
        # only clip (log scale) when the data goes down to 1e-16 or below
        vmin = max(np.min(mmr), 1e-16)
        if vmin > 1e-16:
            vmin = None
        ax = self.map_2D(mmr, location=location, dim=dim, ax=ax, cmap=cmap, log=True, vmin=vmin, *args,
                         **kwargs)
        index = get_index(self.grid, location, dim)
        dim_display = "pressure" if self.vertical_in_pressure else dim
        ax.set_title("Aerosol MMR: log(%s) at %s %s" % (aerosol, dim_display, self.get_value_dim(index, dim)))
        os.makedirs(os.path.join(self.out_folder, aerosol), exist_ok=True)
        # NOTE(review): unlike x_map/t_map, save_plot() receives a second positional argument here; check its
        # meaning in save_plot()'s signature.
        self.save_plot(os.path.join(aerosol, "a_map"), "%s_%s" % (dim, index))

    def a_maps(self, aerosols=None, dim="latitude", *args, **kwargs):
        """Aerosols Mass Mixing ratio 2D maps over multiple locations. See :func:`plot_2Dmap` for further parameters. You should set beforehand:

        - :attr:`self.altitudes <altitudes>` when :attr:`dim = "altitude"`,
        - :attr:`self.latitudes <latitudes>` when :attr:`dim = "latitude"` (default),
        - :attr:`self.longitudes <longitudes>` when :attr:`dim = "longitude"`

        ``aerosols`` defaults to all aerosols of ``self.aerosols``; a single aerosol name is accepted.
        """
        if aerosols is None:
            aerosols = self.aerosols
        if isinstance(aerosols, str):
            aerosols = [aerosols]
        for aerosol in aerosols:
            self.plot_2D(self.a_map, aerosol=aerosol, dim=dim, *args, **kwargs)

    def plot_xprofile(self, *args, **kwargs):
        """Alias of :func:`plot_x`."""
        return self.plot_x(*args, **kwargs)

    def plot_x(self, latitude=None, longitude=None, ax=None, title=None, figsize=(5, 3.5), *args, **kwargs):
        """Plot VMRs (gas mix profiles) of one vertical column.

        The "background" gas is plotted as 1 minus the sum of the other (non-string) VMRs.

        Returns:
            tuple: (dict of legend handles per gas, minimum VMR (at least 1e-12), maximum VMR).
        """
        fig, ax, save = self.figure(ax, figsize)
        if 'label' in kwargs:
            del kwargs['label']  # will be molecule names
        gas_legends = {}
        min_mix = 1
        max_mix = 0
        for mol_name, mix in self.gas_mix_ratio.items():
            # isinstance() first: comparing an array to a string is elementwise and cannot be used in an `if`
            if isinstance(mix, str) and mix == 'background':
                others = [other for other in self.gas_mix_ratio.values() if not isinstance(other, str)]
                mix = 1 - sum(others)  # elementwise, as in x_map()
            if isinstance(mix, np.ndarray):
                max_mix = max(max_mix, mix.max())
                min_mix = min(min_mix, mix.min())
            elif not isinstance(mix, str):
                max_mix = max(max_mix, mix)
                min_mix = min(min_mix, mix)
            color = self.x_colors[mol_name]
            self.plot_column(ax, mix, self.pressure, latitude=latitude, longitude=longitude, color=color,
                             label=mol_name, *args, **kwargs)
            gas_legends[mol_name] = mpl.lines.Line2D([0], [0], color=color, label=mol_name)
        plt.yscale('log')
        plt.xscale('log')
        min_mix = max(min_mix, 1e-12)
        plt.xlim(min_mix, 1)
        if save:
            self.x_legend(ax=ax, fig=fig, title=title)
            self.save_plot(self.save_column('mixratio'))
        return gas_legends, min_mix, max_mix

    def x_legend(self, ax, fig, *args, **kwargs):
        """Decorate VMR profile plot(s): inverted pressure axis, title, labels and a figure-level legend on the right."""
        if isinstance(ax, np.ndarray):
            self.legend2D(ax)
            ax0 = ax.flatten()[0]
            ax1 = ax.flatten()[-1]
        else:
            ax0 = ax
            ax1 = ax
        ax0.invert_yaxis()
        self.set_title(ax0, *args, **kwargs)
        plt.xlabel('Mixing ratio')
        plt.ylabel('Pressure (Pa)')
        plt.tight_layout()
        h, labels = ax1.get_legend_handles_labels()
        fig.subplots_adjust(left=0.2, right=0.82, wspace=0.25, hspace=0.35)
        fig.legend(h, labels, loc='center right', bbox_to_anchor=(1, 0.5), ncol=1, prop={'size': 11},
                   frameon=False)

    def plot_xprofiles(self, *args, **kwargs):
        """Plot VMRs (gas mix profiles) of multiple columns. Set :attr:`self.latitudes <latitudes>` and :attr:`self.longitudes <longitudes>` for this beforehand."""
        return self.plot_columns(self.plot_xprofile, name="mixratio", legend=self.x_legend, *args, **kwargs)

    def plot_tp(self, latitude=None, longitude=None, ax=None, title=None, figsize=(5, 3.5), *args, **kwargs):
        """Plot TP profile of one vertical column."""
        fig, ax, save = self.figure(ax, figsize)
        self.plot_column(ax, self.temperature, self.pressure, latitude=latitude, longitude=longitude, *args, **kwargs)
        plt.yscale('log')
        if save:
            self.tp_legend(ax=ax, title=title)
            self.save_plot(self.save_column("tp"))

    def tp_legend(self, ax, fig=None, *args, **kwargs):
        """Decorate TP profile plot(s): inverted pressure axis, title, labels and legend."""
        if isinstance(ax, np.ndarray):
            self.legend2D(ax)
            ax = ax.flatten()[0]
        ax.invert_yaxis()
        self.set_title(ax, *args, **kwargs)
        plt.xlabel('Temperature (K)')
        plt.ylabel('Pressure (Pa)')
        plt.tight_layout()
        self.legend(ax)

    def plot_tps(self, *args, **kwargs):
        """Plot TP profiles of multiple columns. Set :attr:`self.latitudes <latitudes>` and :attr:`self.longitudes <longitudes>` for this beforehand."""
        self.plot_columns(self.plot_tp, name="tp", legend=self.tp_legend, *args, **kwargs)

    def plot_zp(self, latitude=None, longitude=None, ax=None, title=None, figsize=(5, 3.5), *args, **kwargs):
        """Plot ZP profile of one vertical column."""
        fig, ax, save = self.figure(ax, figsize)
        self.plot_column(ax, self.pressure, self.z, latitude=latitude, longitude=longitude, *args, **kwargs)
        ax.set_xscale('log')
        if save:
            self.zp_legend(ax=ax, title=title)
            self.save_plot(self.save_column("zp"))

    def zp_legend(self, ax, fig=None, *args, **kwargs):
        """Decorate ZP profile plot(s): inverted pressure axis, title, labels and legend."""
        if isinstance(ax, np.ndarray):
            self.legend2D(ax)
            ax = ax.flatten()[0]
        ax.invert_xaxis()
        self.set_title(ax, *args, **kwargs)
        plt.ylabel('Altitude ($10^6$m)')
        plt.xlabel('Pressure (Pa)')
        plt.tight_layout()
        self.legend(ax)

    def plot_zps(self, *args, **kwargs):
        """Plot ZP profiles of multiple columns. Set :attr:`self.latitudes <latitudes>` and :attr:`self.longitudes <longitudes>` for this beforehand."""
        self.plot_columns(self.plot_zp, name="zp", legend=self.zp_legend, *args, **kwargs)

    def plot_spectrum(self, mode=None, noise=True, ax=None, save=False, time=None, phase=None, resolution=None,
                      xlabel=None, ylabel=None, dashes=(), linewidth=.5, x_axis="wls", figsize=(5.3, 3.5),
                      legend=True, xlog=True, ylog=False, color=None, label=None, title=None, time_units=None,
                      *args, **kwargs):
        """Plot a spectrum.

        Args:
            mode (str) : transmission/emission/lightcurve/phasecurve/None. None takes the value of the spectrum in the main model (transmission by default), which can be noised. Defaults to None.
            noise (bool) : Plot noised spectrum. If it is set to a value, it overwrites the current noise using a normal distribution.
            ax: Axis to plot on. If None, a new figure is created and the plot is saved.
            save (bool): Save the plot (forced to True when ``ax`` is None). Defaults to False.
            time (ndarray, optional): List of times to plot (in curve modes only) in seconds (or astropy).
            phase (float, list or bool, optional): Phase(s) to plot (curve modes only). True selects all the phases
                of ``mode`` (converted to degrees).
            resolution (int, optional): Number of points to bin to. Defaults to None.
            dashes (sequence, optional): Dash pattern passed to ``ax.plot``. Defaults to () (solid line).
            x_axis (str, optional): Attribute of the spectrum used as x values. Defaults to "wls".

        Returns:
            1 if there is nothing to plot, None otherwise.
        """
        if ax is None:
            fig, ax = plt.subplots(figsize=figsize)
            save = True
        suffix = None
        phase, time, label_prefix, time_units = self.init_time(phase, time, time_units)
        times = None
        phases = None
        try:
            if phase is not None or time is not None:
                phases = np.degrees(self.getattr(mode, "phases"))
                try:
                    if time is not None:
                        times = self.getattr(mode, "times")
                except Exception as e:
                    self.error(f"Did not compute times from phases. Maybe set self.model.orbit.period? Full error:\n{e}")
        except AttributeError:
            self.error(f"Mode {mode} has no 'phases' attribute. Maybe another mode?")
            return 1
        if phase is True and phases is not None:
            phase = phases
        elif phase is None:
            phase = [None]
            # no phase: a single spectrum, with no phase information in labels/titles
            self.ph_index = 0
            phases = [np.nan]
        if isinstance(phase, (float, int, str, np.number)):
            phase = [float(phase)]
        for ph in phase:
            try:
                # also computes self.ph_index
                spectrum = self.flux(mode, phase=ph, resolution=resolution, noise=noise, ax=ax, color=color)
            except Exception:
                spectrum = None
            # an ndarray can be used for 2D plots, not for a spectrum
            if spectrum is None or isinstance(spectrum, np.ndarray):
                self.error(f"plot_spectrum(): Mode '{mode}' has no 'spectrum' to plot (phase = {ph}). Maybe try another mode (lightcurve,emission,...).")
                # TODO: iterate over mode to try to find a good one if mode is None
                return 1
            # 'ph_index' is set in self.flux()
            if times is not None:
                t_value, units = time_fmt(times[self.ph_index], time_units)
                suffix = f"{t_value:.2f} {units}"
            elif np.isfinite(phases[self.ph_index]):
                suffix = f"{phases[self.ph_index]:.1f}°"
            else:
                suffix = None  # no phase: avoid "@ nan°" in titles
            if len(phase) > 1:
                p_label = f"{label} @ {suffix}"
                if label is None:
                    p_label = f"{label_prefix} = {suffix}"
            else:
                p_label = label
                if label is None:
                    p_label = self.label
            ax.plot(getattr(spectrum, x_axis), spectrum.value, label=p_label, color=color, dashes=dashes,
                    linewidth=linewidth, *args, **kwargs)
        if xlog:
            ax.set_xscale('log')
            ax.xaxis.set_major_formatter(ticker.FormatStrFormatter('%.1f'))
        if ylog:
            ax.set_yscale("log")
        if xlabel is None:
            xlabel = label_from_dim(x_axis)
        ax.set_xlabel(xlabel)
        if ylabel is None:
            ylabel = self.spectrum_label(mode)
        ax.set_ylabel(ylabel)
        if title is None:
            title = self.title
        if len(phase) < 2:
            p_title = title
            if title and suffix:
                p_title = f"{title} @ {suffix}"
            ax.set_title(p_title)
        if save:
            if legend:
                ax.legend()
            plt.tight_layout()
            output_name = "spectrum"
            if mode:
                output_name += "_" + mode
            self.save_plot(output_name)

    def transmission_spectrum(self, *args, **kwargs):
        """Calls :func:`plot_spectrum` with :attr:`mode` = 'transmission'."""
        return self.plot_spectrum(mode="transmission", *args, **kwargs)

    def emission_spectrum(self, *args, **kwargs):
        """Calls :func:`plot_spectrum` with :attr:`mode` = 'emission'."""
        return self.plot_spectrum(mode="emission", *args, **kwargs)

    def transmittance_map(self, wl=None, wn=None, phase=None, mode="transmission", zmax=None, r_factor=None,
                          ax=None, title="Transmittance at ", cmap="gnuplot", overlay=True, star_out=True,
                          figsize=None, save_name="transmittances/transmittance", pcolormesh=False,
                          core_color="black", *args, **kwargs):
        """Plot a map of the transmittance, as seen by the observer.

        Args:
            wl (float, optional): Select wavelength (or inferior). A list draws one subplot per wavelength. Defaults to None.
            wn (float, optional): Select wavenumber. Defaults to None.
            phase (float, optional): Select phase in degrees (in lightcurve mode). A list/array produces one map
                (and one file) per phase; True selects all the phases of the 'lightcurve' mode. Defaults to None.
            mode (str, optional): transmission/lightcurve. Defaults to "transmission". Forced to "lightcurve" when a phase is given.
            zmax (float, optional): Truncate plot at altitude :attr:`zmax`, scaled using :attr:`h_unit`. Defaults to max altitude.
            r_factor (float, optional): Scale planet core radius (Rp). Can be used to enlarge (artificially) the atmosphere. Defaults to 1.
            ax: Polar axis to draw on. If None, a new figure is created and the map is saved.
            title (str, optional): Title prefix, followed by the selected wavelength. Defaults to "Transmittance at ".
            cmap (str, optional): Colormap. Defaults to "gnuplot".
            overlay (bool, optional): Activates ticks, title and colorbar. Defaults to True.
            star_out (bool, optional): Hide part of the transmittance that is out of the star. Defaults to True.
            figsize (tuple, optional): Size of the figure. Defaults to None ((5, 3.5) for a single map).
            save_name (str, optional): Change name base of output file. Defaults to "transmittances/transmittance".
            pcolormesh (bool, optional): Activates use of pcolormesh. Otherwise use contourf. Defaults to False.
            core_color (str, optional): Color of the disk drawn for the planet core. Defaults to "black".

        Returns:
            The drawn contour set/mesh, or None when there is nothing to plot or when iterating over phases.
        """
        if phase is True:
            try:
                mode = "lightcurve"
                phase = np.degrees(self.mode(mode).phases)
            except Exception:
                phase = 0
        if isinstance(phase, (np.ndarray, list)):
            for ph in np.asarray(phase, dtype=float):
                self.transmittance_map(wl=wl, wn=wn, phase=ph, mode=mode, zmax=zmax, r_factor=r_factor, ax=ax,
                                       title=title, cmap=cmap, overlay=overlay, star_out=star_out, figsize=figsize,
                                       save_name=f"{save_name}_{ph:.3f}", pcolormesh=pcolormesh,
                                       core_color=core_color, *args, **kwargs)
            return
        if phase is not None:
            phase = np.radians(float(phase))
            mode = "lightcurve"  # force lightcurve mode
        ws, wl, wn, w_units = self.init_spectral(wl, wn, default_wl=1, mode=mode)
        if zmax is None:
            zmax = self.zmax
        if r_factor is not None:
            self.r_factor = r_factor
        if hasattr(wl, '__len__'):
            if len(wl) > 1:
                fig = plt.figure(figsize=figsize)
                # iterate over wavelengths
                ncols = int((len(wl) + 1) / 2)
                nrows = int((len(wl) + 1) / ncols)
                # phase is in radians at this point, but transmittance_map() expects degrees
                phase_deg = None if phase is None else np.degrees(phase)
                axes = []
                mappable = None
                for i, wavelength in enumerate(wl):
                    sub_ax = fig.add_subplot(nrows, ncols, i + 1, polar=True)
                    axes.append(sub_ax)
                    mappable = self.transmittance_map(wl=float(wavelength), phase=phase_deg, mode=mode, zmax=zmax,
                                                      r_factor=r_factor, ax=sub_ax, title="", cmap=cmap,
                                                      overlay=overlay, star_out=star_out, figsize=figsize,
                                                      save_name=save_name, pcolormesh=pcolormesh,
                                                      core_color=core_color, *args, **kwargs)
                fig.subplots_adjust(right=1)
                if mappable is not None:
                    clipped_colorbar(mappable, format='%.3f', ax=axes)
                self.save_plot(save_name)
                return mappable
            wl = wl[0]
        save = ax is None
        error = None
        try:
            tr = self.transmittance(phase, wl)
        except Exception as e:
            tr, error = None, e
        if tr is None:
            self.warning(f"No transmittance to plot ('{mode}' mode). If you need it, set store_transmittance(s) to True (see parameters for each module). The exact error is:\n{error if error is not None else ''}")
            return
        if mode == "lightcurve" and phase is not None and star_out:
            try:
                # we need to re-calculate intersection of transmittance and star since we did NOT store it for
                # every phase (too costly)
                self.mode(mode).transmittance_surfaces = False  # for plots, cells are either ENTIRELY in front the star, or are not
                tr, _dist = np.subtract(1, self.mode(mode).star_rays_opacity(phase, np.subtract(1, tr)))
            except Exception:
                self.warning("Failed to compute star 'shadow' over transmittance. Skipping and plotting transmittance as is.")
        r = self.rays(mode).r[::-1] / self.h_unit
        z = r - self.Rp
        r = z + self.R  # scaling: R = Rp * r_factor
        z_idx = np.where(z < zmax)
        z = z[z_idx]
        r = r[z_idx]
        tr = tr[::-1][z_idx]
        try:
            rebuild = not self.rays(mode).angles_limits[0]
        except Exception:
            rebuild = True
        if rebuild:
            # rays were not properly written in h5 so we compute them again
            self.rays(mode).n_radial = tr.shape[0]
            self.rays(mode).n_angular = tr.shape[1]
            self.rays(mode).build(self.model)
        th = self.rays(mode).angles
        if ax is None:
            fig = plt.figure(figsize=figsize if figsize is not None else (5, 3.5))
            ax = fig.add_subplot(111, projection='polar')
        repeats = int(np.ceil(180 / len(th)))
        tr = np.repeat(tr, repeats, axis=1)  # smoothing things
        th_0 = self.rays(mode).angles_limits[0]
        th = np.linspace(th_0, th_0 + 2 * np.pi, tr.shape[1])
        x, y = np.meshgrid(th, r)
        if pcolormesh:
            mappable = ax.pcolormesh(x, y, tr, cmap=cmap, vmin=0, vmax=1, *args, **kwargs)
        else:
            mappable = ax.contourf(x, y, tr, cmap=cmap, levels=20, vmin=0, vmax=1, *args, **kwargs)
        if mode == "lightcurve" and phase is not None and star_out:
            try:
                # outline of the star (best effort: skipped if its position cannot be computed)
                Rs = self.model.star.radius / self.h_unit
                rS, aS = self.mode("lightcurve").star_projected_coordinates(phase)
                rS = rS / self.h_unit  # not in place: rS may be an array owned by the mode
                circle = pl.Circle((rS * np.sin(-aS), rS * np.cos(aS)), Rs,
                                   transform=(Affine2D().rotate(ax._theta_offset.get_matrix()[0, 2])
                                              + ax.transProjectionAffine + ax.transAxes),
                                   color="black", alpha=0.3)
                ax.add_artist(circle)
            except Exception:
                pass
        ax.set_theta_zero_location("N")
        ax.set_rmin(0)
        if np.isfinite(zmax):
            ax.set_rmax(self.R + zmax)
        # r and z are ordered from the top of the atmosphere downwards (reversed above)
        ticks = [0, -1]
        if self.r_factor < 1:
            ticks = [0, int(len(z) / 2), -1]
        ax.set_rgrids(r[ticks])
        ax.set_yticklabels(["{:,.1f}".format(level) + ' Mm' for level in z[ticks[:-1]]] + ["0"], fontsize=8)
        ax.grid(linewidth=1)
        # pin the 8 angular ticks (every 45°) so that the labels below match them
        ax.set_xticks(np.deg2rad(np.arange(0, 360, 45)))
        if overlay:
            ax.set_xticklabels(["0°", "45°", "90°", "135°", "180°", "225°", "270°", "315°"], fontsize=9)
        else:
            ax.set_xticklabels([])
        ax.set_rlabel_position(0)
        core = pl.Circle((0, 0), self.R,
                         transform=(Affine2D().rotate(ax._theta_offset.get_matrix()[0, 2])
                                    + ax.transProjectionAffine + ax.transAxes),
                         color=core_color, alpha=1)
        ax.add_artist(core)
        if overlay:
            ax.set_title(f"{title}{ws[self.w_index]:.2f} {units_from_dim(w_units)}", pad=15)
        plt.tight_layout(pad=1)
        if save:
            if overlay:
                fig.colorbar(mappable, ticks=ticker.LinearLocator(11), format='%.1f', pad=.1)
            self.save_plot(save_name)
        return mappable

    def transmittance_maps(self, wl=None, wn=None, save_name="transmittances/transmittance", *args, **kwargs):
        """Same parameters as :func:`transmittance_map` but create new files for each wl/wn.

        ``wl`` (or ``wn`` when ``wl`` is None) may be a single value or a sequence of values.
        """
        if wl is not None:
            for w in ([wl] if np.ndim(wl) == 0 else wl):
                self.transmittance_map(wl=w, save_name=f"{save_name}_{w:.3f}", *args, **kwargs)
        elif wn is not None:
            for w in ([wn] if np.ndim(wn) == 0 else wn):
                self.transmittance_map(wn=w, save_name=f"{save_name}_{w:.3f}", *args, **kwargs)

    def transmittance_animation(self, wl=None, wn=None, phase=None, filename=None, prefix="transmittances/transmittance", *args,
                                **kwargs):
        """Transforms maps generated by :func:`transmittance_map` to a GIF animation.

        Frames are read from ``<out_folder>/<prefix>[_<wl>][_<wn>]_<phase>_<self.suffix>.pdf``; missing frames
        are skipped. Requires the optional ``wand`` package.

        Args:
            wl (float, optional): Wavelength appended to the frame and GIF names.
            wn (float, optional): Wavenumber appended to the frame and GIF names.
            phase (float or list, optional): Phase(s) of the frames. None or True selects all the phases of the
                'lightcurve' mode (converted to degrees).
            filename (str, optional): Output file. Defaults to ``<out_folder>/transmittances_<suffix>.gif``.
            prefix (str, optional): Name base of the frames. Defaults to "transmittances/transmittance".
        """
        try:
            from wand.exceptions import WandError
            from wand.image import Image
        except ImportError as e:
            self.error(f"Wand or MagickWand not installed. Cannot GIFify. Full error:\n{e}")
            return
        if phase is None or phase is True:
            try:
                phase = np.degrees(self.mode("lightcurve").phases)
            except AttributeError:
                self.debug("transmittance_animation: 'lightcurve' mode has not been computed.")
                return
        # numbers are needed for the '{:.3f}' frame names (a single phase may be given as a string)
        phases = [float(ph) for ph in np.atleast_1d(phase)]
        suffix = self.suffix
        if wl is not None:
            prefix = f"{prefix}_{wl:.3f}"
            suffix = f"{suffix}_{wl:.3f}"
        if wn is not None:
            prefix = f"{prefix}_{wn:.3f}"
            suffix = f"{suffix}_{wn:.3f}"
        if filename is None:
            filename = "%s/transmittances_%s.gif" % (self.out_folder, suffix)
        with Image() as gif_output:
            for ph in phases:
                frame_name = f"{self.out_folder}/{prefix}_{ph:.3f}_{self.suffix}.pdf"
                try:
                    # the sequence stores a copy of the frame, which can therefore be closed right away
                    with Image(filename=frame_name) as frame:
                        gif_output.sequence.append(frame)
                except Exception as e:
                    self.debug(f"transmittance_animation: skipping frame {frame_name} ({e})")
            for frame in gif_output.sequence:
                frame.delay = 20
            try:
                gif_output.type = 'optimize'
            except WandError:
                self.critical("No transmittance images found. Cannot generate GIF.")
                return
            gif_output.save(filename=filename)
        print("Saved %s" % filename)

    def emission_map(self, wl=None, wn=None, mode="emission",
                     ax=None, title="Emission at ", cmap="gnuplot",
                     overlay=True, figsize=(5, 3.5),
                     *args, **kwargs):
        """Plot emission at a specific wavelength :attr:`wl` (or closest inferior wavelength, in micrometer) or wavenumber.

        The raw flux of ``mode`` is drawn as a (longitude, latitude) map. With several wavenumbers, one subplot
        per wavenumber is drawn with a shared color scale. Requires 'store_raw_flux' (a warning is logged
        otherwise).

        Args:
            ax: Axis to draw on. If None, a new figure is created and the map is saved as "emission".
            title (str, optional): Title prefix, followed by the selected wavelength. Defaults to "Emission at ".
            cmap (str, optional): Colormap. Defaults to "gnuplot".
            overlay (bool, optional): Show the title. Defaults to True.
            figsize (tuple, optional): Size of the figure. Defaults to (5, 3.5).

        Returns:
            The contour set for a single map, None otherwise.
        """
        ws, wl, wn, w_units = self.init_spectral(wl, wn, default_wl=1, mode=mode)
        if hasattr(wn, '__len__'):
            if len(wn) > 1:
                fig = plt.figure(figsize=figsize)
                # iterate over wavelengths
                nrows = int((len(wn) + 1) / 2)
                ncols = int((len(wn) + 1) / nrows)
                try:
                    # common color scale for all subplots
                    indices = self.find_spectral(wl, wn)
                    vmin = self.mode(mode).raw_flux[..., indices].min()
                    vmax = self.mode(mode).raw_flux[..., indices].max()
                except Exception:
                    self.warning(
                        "No raw emission flux to plot. If you need it, set 'store_raw_flux' to True.")
                    return
                axes = []
                mappable = None
                for i, w in enumerate(wn):
                    sub_ax = fig.add_subplot(nrows, ncols, i + 1)
                    axes.append(sub_ax)
                    mappable = self.emission_map(wn=float(w), mode=mode, ax=sub_ax, title="", cmap=cmap,
                                                 overlay=overlay, vmax=vmax, vmin=vmin, *args, **kwargs)
                fig.subplots_adjust(right=1)
                if mappable is not None:
                    cbar = clipped_colorbar(mappable, format='%.3f', ax=axes)
                    cbar.ax.tick_params(labelsize=7)
                self.save_plot("emission")
                return
            wn = wn[0]
        save = ax is None
        if save:
            fig = plt.figure(figsize=figsize)
            ax = fig.add_subplot(111)
            ax.set_xlabel('East Longitude')
            ax.set_ylabel('Latitude')
        try:
            w_index = self.find_spectral(wl, wn)
            flux = self.mode(mode).raw_flux[..., w_index]
            # repeat the first longitude so that the map is closed
            flux = np.concatenate([flux, flux[:, 0:1]], axis=1)
        except Exception:
            self.warning("No raw emission flux to plot. If you need it, set 'store_raw_flux' to True.")
            return
        longitudes = np.concatenate((self.grid.mid_longitudes, self.grid.mid_longitudes[0:1] + 2 * np.pi))
        x, y = np.degrees(np.meshgrid(longitudes, self.grid.mid_latitudes))
        mappable = ax.contourf(x, y, flux, locator=ticker.LinearLocator(100), cmap=cmap, *args, **kwargs)
        if overlay:
            ax.set_title(f"{title}{ws[w_index]:.3f} {units_from_dim(w_units)}", pad=15)
        plt.tight_layout(pad=1)
        if save:
            fig.colorbar(mappable, format=ticker.LogFormatter(1.01, labelOnlyBase=False), pad=.1)
            self.save_plot("emission")
        return mappable

    def plot_emission(self, *args, **kwargs):
        """Deprecated: does not plot anything, only prints a notice. Use :func:`emission_map` instead."""
        print("This function will not be supported in future releases. Please use emission_map() instead.")
[docs]
class CurvePlot(BasePlot):
    """Plots for (light/phase)curves."""

    def plot_curve(self, mode="phasecurve", wl=None, wn=None, ax=None,
                   label=None, title="Phasecurve",
                   x_axis: Literal["times", "phases"] = "times",
                   xlabel=None, x_units=None, ylabel='Normalized flux',
                   legend=True, figsize=(5, 3.5), *args, **kwargs):
        """Plot a curve from a `mode` (phasecurve/lightcurve). Use either `wl` or `wn`, not both.

        Args:
            mode (str, optional): "phasecurve" or "lightcurve". Defaults to "phasecurve".
            wl (float, optional): Wavelength of the curve (can be a list). Defaults to None.
            wn (float, optional): Wavenumber of the curve (can be a list). Defaults to None.
            ax (optional): Axes passed to ``self.figure()``; the ``save`` flag it
                returns decides whether the plot is saved at the end.
            label (str, optional): Curve label. With several spectral points, each
                curve is labelled "<label> @ <position>" (or just "<position>").
            title (str, optional): Title prefix, used only when a single curve is drawn.
            x_axis (str, optional): "times" or "phases", forwarded to ``x_axis_curve()``.
            xlabel (str, optional): X label. Defaults to the one given by ``x_axis_curve()``.
            x_units (optional): Units forwarded to ``x_axis_curve()``.
            ylabel (str, optional): Y label. Defaults to 'Normalized flux'.
            legend (bool, optional): Draw a legend when the curves carry labels.
            figsize (tuple, optional): Size of the figure. Defaults to (5, 3.5).
            *args, **kwargs: Forwarded to ``ax.plot()``.

        Returns:
            1 if the X axis of `mode` could not be built, None otherwise.
        """
        fig, ax, save = self.figure(ax, figsize)
        ws, wl, wn, w_units = self.init_spectral(wl, wn, mode=mode)
        if isinstance(wn, (float, int)):
            wn = [wn]
        try:
            x, default_xlabel = self.x_axis_curve(x_axis=x_axis, x_units=x_units, mode=mode)
        except Exception as e:
            self.info(f"plot_curve(): no {mode} mode to plot. Full error:\n{e}")
            return 1
        if xlabel is None:
            xlabel = default_xlabel

        multiple = len(wn) > 1
        suffix = ""
        for w in wn:
            curve = self.curve(mode, wn=w)
            if self.substellar_longitude is not None:
                # NOTE(review): order reversed when a substellar longitude is set;
                # the reason is not documented here.
                curve = curve[::-1]
            # self.w_index comes from curve(), which calls find_spectral()
            suffix = f"{ws[self.w_index]:.1f} {units_from_dim(w_units)}"
            p_label = label
            if multiple:
                p_label = f"{label} @ {suffix}" if label else f"{suffix}"
            ax.plot(x, curve, label=p_label, *args, **kwargs)

        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        # With several curves, each one is labelled by its spectral position even
        # without a user label, so the legend is needed to tell them apart.
        if legend and (label is not None or multiple):
            ax.legend()
        if not multiple:
            p_title = f"{suffix}"
            if title and suffix:
                p_title = f"{title} @ {suffix}"
            ax.set_title(f"{p_title}")
        plt.tight_layout(pad=1)
        if save:
            self.save_plot(mode)

    def plot_phasecurve(self, *args, **kwargs):
        """Plot a phase curve. See parameters of :func:`plot_curve`.

        The title defaults to "Phasecurve" (plot_curve's default) and can be overridden.
        """
        # `mode` is passed positionally so extra positional arguments map to wl, wn, ...
        return self.plot_curve("phasecurve", *args, **kwargs)

    def plot_lightcurve(self, wl=1, *args, **kwargs):
        """Plot a light curve. See parameters of :func:`plot_curve`.

        The title defaults to "Lightcurve" and can be overridden with ``title=``.
        Additional arguments must be given by keyword.
        """
        kwargs.setdefault("title", "Lightcurve")
        return self.plot_curve(mode="lightcurve", wl=wl, *args, **kwargs)

    def plot_2d_flux(self, mode="phasecurve", ax=None, title='Normalized flux',
                     figsize=(5, 3.5), x=None, y=None, flux=None, x_axis="wls", xlog=True,
                     xlabel=None, ylabel='Phase angle (degrees)',
                     cmap='viridis', colorbar_kwargs=None, savename=None,
                     *args, **kwargs):
        """Plot a 2D flux map (pcolormesh), with the X axis the spectral dimension and the Y axis phases.

        If ``flux.max() == 1`` the flux is treated as normalized and 8 evenly spaced
        levels are used. Otherwise it is treated as a difference and fixed levels are
        used: in ppm when ``flux.max() > 1``, divided by 1e6 otherwise.

        Args:
            mode (str, optional): "phasecurve" or "lightcurve". Defaults to "phasecurve".
            ax (optional): Axes passed to ``self.figure()``; the ``save`` flag it returns
                decides whether a colorbar is added and the plot saved.
            title (str, optional): Axes title. Defaults to 'Normalized flux'.
            figsize (tuple, optional): Size of the figure. Defaults to (5, 3.5).
            x (array, optional): X coordinates. Defaults to the spectral grid of `mode`.
            y (array, optional): Y coordinates. Defaults to the phases of `mode` converted to degrees.
            flux (array, optional): 2D values to plot. Defaults to ``self.curve(mode)``.
            x_axis (str): Choose X axis as "wls" or "wns", for wavelengths or wavenumbers, respectively.
            xlog (bool, optional): Use a log scale on the X axis. Defaults to True.
            xlabel (str, optional): Defaults to ``label_from_dim(x_axis)``.
            ylabel (str, optional): Defaults to 'Phase angle (degrees)'.
            cmap (str or Colormap, optional): Colormap of the mesh. Defaults to "viridis".
            colorbar_kwargs (dict, optional): Extra arguments for ``fig.colorbar``.
                Defaults to ``dict(format='%.3f')``.
            savename (str, optional): Name given to ``save_plot()``. Defaults to ``mode + "s"``.
            *args, **kwargs: Accepted for compatibility; currently unused.

        Returns:
            The mesh returned by ``ax.pcolormesh``, or 1 if `flux` could not be obtained.
        """
        if colorbar_kwargs is None:
            colorbar_kwargs = dict(format='%.3f')
        fig, ax, save = self.figure(ax, figsize)
        if flux is None:
            try:
                flux = self.curve(mode)
            except Exception:
                self.info("plot_curves(): no %s to plot." % mode)
                return 1
        if x is None:
            x = self.get_spectral(w_units=x_axis, mode=mode)[0]  # wls or wns
        if y is None:
            y = np.degrees(self.mode(mode).phases)

        fmin, fmax = flux.min(), flux.max()
        if fmax == 1:
            # plot a normalized flux
            levels = np.linspace(fmin, fmax, 8)
        else:
            # plot a difference
            fixed = np.array([-200, -100, -50, -20, -5, 0, 5, 20, 50, 100, 200, 700], dtype=float)  # in ppm
            if fmax < 1:
                fixed = fixed / 1e6  # not in ppm
            levels = np.concatenate(([fmin], fixed, [fmax]))
        levels = levels[(levels >= fmin) & (levels <= fmax)]
        if fmax != 1:
            # BoundaryNorm expects increasing boundaries: drop duplicates such as
            # fmin == 0 coinciding with the 0 level.
            levels = np.unique(levels)

        cmap = plt.get_cmap(cmap)  # accepts a name or a Colormap instance
        norm = mpl.colors.BoundaryNorm(levels, ncolors=cmap.N, clip=True)
        # The norm is built for `cmap`, so the same colormap must be passed to the mesh.
        cs = ax.pcolormesh(x, y, flux, cmap=cmap, norm=norm, zorder=-9)
        # Artists below zorder -1 (the mesh) are rasterized in vector outputs.
        ax.set_rasterization_zorder(-1)
        if xlabel is None:
            xlabel = label_from_dim(x_axis)
        if xlog:
            ax.set_xscale('log')
        else:
            # NOTE(review): this formats the Y axis although the branch concerns the
            # X scale — confirm intent.
            ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%.2f'))
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        plt.tight_layout(pad=1)
        if save:
            fig.colorbar(cs, ticks=levels, spacing="uniform", **colorbar_kwargs)
            plt.tight_layout(pad=1)
            if savename is None:
                savename = mode + "s"
            self.save_plot(savename)
        return cs

    def plot_2d_phasecurve(self, *args, **kwargs):
        """2D phase curve. See parameters of :func:`plot_2d_flux`."""
        # `mode` is passed positionally so extra positional arguments map to ax, title, ...
        return self.plot_2d_flux("phasecurve", *args, **kwargs)

    def plot_2d_lightcurve(self, *args, **kwargs):
        """2D light curve. See parameters of :func:`plot_2d_flux`."""
        return self.plot_2d_flux("lightcurve", *args, **kwargs)
[docs]
class Plot(LoadPlot, ModelPlot, CurvePlot):
"""Plot class. Can plot a Model() or a HDF5 file generated by Pytmosph3R."""
def __init__(self, model: Union[Model, str] = None,
             title: Optional[str] = None, label: Optional[str] = None,
             suffix=None, out_folder: str = '.', cmap: str = 'Paired',
             r_factor=1., h_unit=1e6, zmax=np.inf, pmin=None, substellar_longitude=None,
             vertical_in_pressure=None,
             interactive=None, *args, **kwargs):
    """Parameters for the Plots:

    Args:
        model (string or :class:`~pytmosph3r.model.model.Model`) : HDF5 filename from which to read the model (if string), or Model after its computation. A string without the ".h5" extension is first tried as a directory containing "output_pytmosph3r.h5".
        title (str) : Title stored in :attr:`title`.
        label (str) : Most useful when comparing multiple plots. Defaults to the file name of the model, or "Pytmosph3R".
        suffix (str) : Suffix to append to plot filenames. Defaults to "pytmosph3r".
        out_folder (str) : Directory where we will generate the plots (created if missing).
        cmap (str) : Colormap (name or instance) stored in :attr:`cmap`. Defaults to 'Paired'.
        r_factor (float) : Factor with which the planet radius will be scaled. (Below 1, the radius is smaller, so the atmosphere looks larger). Defaults to 1.
        h_unit (float) : Units of the heightscale. Defaults to 1E6 (= Mm).
        zmax (float) : Maximum height to plot (can be used to crop the atmosphere), scaled via :attr:`h_units`. Defaults to infinity.
        pmin (float) : Min (top) pressure to plot.
        substellar_longitude (float) : Longitude of the substellar point (in degrees), converted to float when given.
        vertical_in_pressure (bool) : Use pressure as vertical axis.
        interactive (bool) : Activate/deactivate showing plots. Plot.interactive can also be changed for all Plot objects. When plots are not interactive, the 'Agg' backend is selected.
    """
    super().__init__(self.__class__.__name__, *args, **kwargs)
    self.h_unit = h_unit
    """Height unit scaling. By default 1e6, i.e., Mm."""
    self.zmax = zmax
    """Max altitude (in Mm) to plot."""
    self.r_factor = r_factor
    """Radius factor (for visual purposes). By default 1."""
    self._p_min = pmin
    """Min (top) pressure to plot."""
    self.vertical_in_pressure = vertical_in_pressure
    """Use pressure as vertical axis ."""
    self.substellar_longitude = substellar_longitude
    """Longitude of the substellar point (in degrees)."""
    if substellar_longitude is not None:
        self.substellar_longitude = float(substellar_longitude)

    self.f = None
    if isinstance(model, str):
        if os.path.splitext(model)[-1] != ".h5":
            # Accept a directory containing the default Pytmosph3R output file.
            default_output = os.path.join(model, "output_pytmosph3r.h5")
            if os.path.isfile(default_output):
                model = default_output
            else:
                self.warning("Input file (%s) extension unrecognized. Not .h5?" % model)
        self.f = HDF5Input(model)
        self.filename = model
    else:
        self._model = model

    if interactive in (False, True):
        # user can set Plot.interactive to change all plots behavior, hence the 'None' by default.
        self.interactive = interactive
    if not self.interactive:
        mpl.use('Agg')

    self.title = title
    # matplotlib.cm.get_cmap was removed in matplotlib 3.9; pyplot.get_cmap accepts
    # the same arguments (a name or a Colormap) on old and new versions.
    self.cmap = plt.get_cmap(cmap)
    self.suffix = suffix
    if self.suffix is None:
        self.suffix = "pytmosph3r"
    self.out_folder = out_folder
    self.ph_index = None
    # exist_ok avoids the race between an existence check and the creation.
    os.makedirs(self.out_folder, exist_ok=True)

    if label is None:
        if isinstance(model, str):
            label = path_leaf(model)
        elif model is not None and 'filename' in model.__dict__ and model.filename:
            label = path_leaf(model.filename)
        else:
            label = "Pytmosph3R"
    self.label = label
    self.p_id = label
@property
def idx_latitude(self):
    """Index given by ``get_latitude_index`` for ``self.latitude`` and ``self.n_latitudes``."""
    latitude, n_latitudes = self.latitude, self.n_latitudes
    return get_latitude_index(latitude, n_latitudes)
@property
def idx_longitude(self):
    """Index given by ``get_longitude_index`` for ``self.longitude`` and ``self.n_longitudes``."""
    longitude, n_longitudes = self.longitude, self.n_longitudes
    return get_longitude_index(longitude, n_longitudes)