from typing import Union
import datetime
import inspect
import sys
import warnings
from functools import wraps
from typing import Literal, Optional
import astropy.units as u
import numpy as np
[docs]
def get_attributes(obj, attr=None):
    """Return the public, non-method attributes of ``obj``.

    An attribute is kept when its name neither starts nor ends with an
    underscore, when it passes the optional ``attr`` filter, and when its
    value is not a bound method.

    Args:
        obj: Object to inspect. If it is a ``dict``, its ``items()`` view is
            returned directly and no filtering is applied.
        attr: Optional container of attribute names; when given, only names
            contained in it are returned.

    Returns:
        ``obj.items()`` for a dict, otherwise a list of ``[name, value]``
        pairs in the order produced by ``dir(obj)``.
    """
    if isinstance(obj, dict):
        return obj.items()

    items = []
    for name in dir(obj):
        try:
            if name.startswith('_') or name.endswith('_'):
                continue
            if attr is not None and name not in attr:
                continue
            value = getattr(obj, name)
        except Exception:
            # Best effort: an attribute that cannot be tested or read (e.g. a
            # property that raises) is skipped. Only `Exception` is caught so
            # that KeyboardInterrupt/SystemExit still propagate.
            continue
        if not inspect.ismethod(value):
            items.append([name, value])
    return items
[docs]
def get_methods(obj):
    """Return the names of the public routines defined on the class of ``obj``.

    Names that start or end with an underscore are left out. The names come
    from ``inspect.getmembers`` applied to ``obj.__class__``.
    """
    routines = inspect.getmembers(obj.__class__, inspect.isroutine)
    names = []
    for name, _ in routines:
        if name.startswith('_') or name.endswith('_'):
            continue
        names.append(name)
    return names
[docs]
def retrieve_name(var, up=0):
    """Retrieve the name(s) bound to the object ``var`` in a calling frame.

    For example, if ``var`` is named 'temperature' in the inspected frame,
    the function returns ``["temperature"]``. Example::

        def test(hello, up=0):
            return retrieve_name(hello, up)

        bonjour = "bonjour"
        test(bonjour)        # returns ["hello"]
        test(bonjour, up=1)  # returns ["bonjour"]

    Args:
        var: Variable to look up (matched by identity, not equality).
        up: Number of frames to climb above the caller of this function.
            0 (default) inspects the function in which ``retrieve_name`` was
            called, 1 its caller, and so on.

    Returns:
        List of local variable names in the selected frame that reference
        ``var``; empty if none does.

    Raises:
        ValueError: If ``up`` is negative or the call stack has fewer than
            ``up + 1`` frames above this function.
    """
    levels = int(up)
    if levels < 0:
        raise ValueError("`up` must be non-negative, got %s." % (up,))

    frame = inspect.currentframe()
    try:
        # One step reaches the caller (up == 0); each further step climbs one
        # more frame.
        for _ in range(levels + 1):
            if frame is None:
                break
            frame = frame.f_back
        if frame is None:
            raise ValueError("No frame found %s level(s) above the caller." % (up,))
        return [var_name for var_name, var_val in frame.f_locals.items() if var_val is var]
    finally:
        # Drop the frame reference to avoid keeping a reference cycle alive.
        del frame
[docs]
def to_radians(array, units="deg"):
    """Convert ``array`` from ``units`` to radians.

    Args:
        array: Values to convert. If it is an astropy ``Quantity``, its own
            unit is used and ``units`` is ignored.
        units: Unit of ``array`` when it is not a ``Quantity``. It is parsed
            with the default unit format first, then with the "cds" format if
            that fails.

    Returns:
        The values expressed in radians (a bare value, not a ``Quantity``).
    """
    if isinstance(array, u.Quantity):
        return array.to_value(u.rad)  # ignore "units" if Quantity
    try:
        to_rad = u.Unit(units).to(u.rad)
    except (ValueError, TypeError):
        # Parsing/conversion failed with the default format: retry as CDS.
        # Anything else (e.g. KeyboardInterrupt) is allowed to propagate.
        to_rad = u.Unit(units, format="cds").to(u.rad)
    return to_rad * array
[docs]
def to_SI(obj: object, unit:Union[u.Unit, str]=None):
    """Strip the unit from an astropy quantity, converting it first.

    Args:
        obj: A ``Quantity``, a tuple/list/ndarray whose items are converted
            one by one, or anything else (returned untouched, ``None``
            included).
        unit: Target unit. When omitted, the quantity's ``si`` form is used.

    Returns:
        The bare value for a ``Quantity``, a list of converted items for a
        tuple/list/ndarray, otherwise ``obj`` itself.
    """
    if isinstance(obj, u.Quantity):
        # Without an explicit unit, rely on the Quantity's own SI conversion.
        return obj.si.value if unit is None else obj.to_value(unit)

    # INFO: the unit should be attached to the container, ie: [1, ..., 3]*u.m, not [1*u.m, ..., 3*u.m]
    if isinstance(obj, (tuple, np.ndarray, list)):
        converted = []
        for item in obj:
            converted.append(to_SI(item, unit))
        return converted

    return obj
[docs]
def get_altitude_index(altitude: Literal['surface', 'middle', 'top'], n_vertical: int):
    """Return the altitude index associated with the keyword ``altitude``.

    Args:
        altitude: One of "surface" (index 0), "middle" (``n_vertical / 2``
            truncated) or "top" (``n_vertical - 1``). An int, NumPy integer,
            list or ndarray is returned as is.
        n_vertical: Number of vertical levels.

    Returns:
        The index for a keyword, otherwise ``altitude`` unchanged.

    Raises:
        TypeError: If ``altitude`` is an unknown keyword or an unsupported
            type.
    """
    # Only compare strings against the keywords: comparing an ndarray with a
    # string yields an array whose truth value is ambiguous.
    if isinstance(altitude, str):
        if altitude == "surface":
            return 0
        if altitude == "top":
            return n_vertical - 1
        if altitude == "middle":
            return int(n_vertical / 2)
        raise TypeError("Altitude %s not recognized." % (altitude,))
    if not isinstance(altitude, (int, np.integer, list, np.ndarray)):
        raise TypeError("Altitude %s not recognized." % (altitude,))
    # TODO: check usefulness of this return, i.e. should we only allow the literals.
    return altitude
[docs]
def get_latitude_index(latitude: Literal['north', 'equator', 'south'], n_latitudes: int):
    """Return the latitude index associated with the keyword ``latitude``.

    Args:
        latitude: One of "north" (index 0), "equator" (``n_latitudes / 2``
            truncated) or "south" (``n_latitudes - 1``). Any other value is
            converted with ``int()``.
        n_latitudes: Number of latitudes.

    Returns:
        The index as an int.

    Raises:
        TypeError: If ``latitude`` is of a type ``int()`` cannot convert.
        ValueError: If ``latitude`` is a string that is neither a keyword nor
            an integer literal.
    """
    # Only compare strings against the keywords: comparing an ndarray with a
    # string yields an array whose truth value is ambiguous.
    if isinstance(latitude, str):
        if latitude == "north":
            return 0
        if latitude == "south":
            return n_latitudes - 1
        if latitude == "equator":
            return int(n_latitudes / 2)
    # TODO: check usefulness of this fallback, i.e. should we only allow the literals.
    try:
        return int(latitude)
    except TypeError as err:
        raise TypeError("Latitude %s not recognized." % (latitude,)) from err
    except ValueError as err:
        # Same exception type as `int()` raised before, with a clearer message.
        raise ValueError("Latitude %s not recognized." % (latitude,)) from err
[docs]
def get_longitude_index(longitude: Literal['day', 'night', 'terminator'], n_longitudes=None):
    """Return the longitude index associated with the keyword ``longitude``.

    The keywords refer to the position of the star if the direction of the
    rays has been defined as (latitude, longitude) = (0, 0).

    Args:
        longitude: One of "day" (``n_longitudes / 2`` truncated), "night"
            (index 0) or "terminator" (``n_longitudes / 4`` truncated). Any
            other value is converted with ``int()``.
        n_longitudes: Number of longitudes; required for "day" and
            "terminator".

    Returns:
        The index as an int.

    Raises:
        TypeError: If ``n_longitudes`` is missing for a keyword that needs
            it, or ``longitude`` is of a type ``int()`` cannot convert.
        ValueError: If ``longitude`` is a string that is neither a keyword
            nor an integer literal.
    """
    # Only compare strings against the keywords: comparing an ndarray with a
    # string yields an array whose truth value is ambiguous.
    if isinstance(longitude, str):
        if longitude == "night":
            return 0
        if longitude in ("day", "terminator"):
            if n_longitudes is None:
                raise TypeError(
                    "`n_longitudes` is required to resolve longitude '%s'." % longitude)
            divisor = 2 if longitude == "day" else 4
            return int(n_longitudes / divisor)
    try:
        return int(longitude)
    except TypeError as err:
        raise TypeError("Longitude %s not recognized." % (longitude,)) from err
    except ValueError as err:
        # Same exception type as `int()` raised before, with a clearer message.
        raise ValueError("Longitude %s not recognized." % (longitude,)) from err
[docs]
def get_index(array, location, dim: Literal['altitude', 'latitude', 'longitude'] = 'altitude'):
    """Return the index of ``location`` along the dimension ``dim`` of ``array``.

    For example, if ``dim`` is 'altitude' and ``location`` is 'surface', it
    returns 0.

    Args:
        array: 3-dimensional array; axis 0 is used for altitude, axis 1 for
            latitude and axis 2 for longitude.
        location: Keyword or index understood by the matching
            ``get_<dim>_index`` helper.
        dim: Dimension in which to resolve ``location``.

    Returns:
        The resolved index. If ``array`` is not 3-dimensional (including
        scalars and any object without an ``ndim`` attribute), a warning is
        emitted and ``array`` itself is returned. If ``dim`` is not
        recognized, a warning is emitted and 0 is returned.
    """
    # `getattr` also covers ints, lists, None, ... which have no `ndim`.
    if getattr(array, "ndim", None) != 3:
        warnings.warn("%s doesn't have 3 dimensions" % type(array))
        return array
    if dim == "altitude":
        return get_altitude_index(location, array.shape[0])
    if dim == "latitude":
        return get_latitude_index(location, array.shape[1])
    if dim == "longitude":
        return get_longitude_index(location, array.shape[2])
    warnings.warn(
        f"Dimension '{dim}' not recognized. Should be among 'altitude', 'latitude' or 'longitude'. "
        f"Returning 0...")
    return 0
[docs]
def get_2D(array, location, dim: Literal['altitude', 'latitude', 'longitude'] = 'altitude'):
    """Return a 2D slice of ``array`` at ``location`` along the dimension ``dim``.

    For example, if ``dim`` is 'altitude', the slice holding all latitudes
    and longitudes at this altitude is returned.

    Args:
        array: 3-dimensional array; axis 0 is used for altitude, axis 1 for
            latitude and axis 2 for longitude.
        location: Keyword or index understood by the matching
            ``get_<dim>_index`` helper.
        dim: Dimension along which to slice.

    Returns:
        The slice of ``array``. If ``array`` is not 3-dimensional (including
        scalars and any object without an ``ndim`` attribute) or ``dim`` is
        not recognized, a warning is emitted and ``array`` is returned
        unchanged.
    """
    # `getattr` also covers ints, lists, None, ... which have no `ndim`.
    if getattr(array, "ndim", None) != 3:
        warnings.warn("%s doesn't have 3 dimensions" % type(array))
        return array
    if dim == "altitude":
        return array[get_altitude_index(location, array.shape[0])]
    if dim == "latitude":
        return array[:, get_latitude_index(location, array.shape[1])]
    if dim == "longitude":
        return array[:, :, get_longitude_index(location, array.shape[2])]
    warnings.warn(
        f"Dimension '{dim}' not recognized. Should be among 'altitude', 'latitude' or 'longitude'. "
        f"Returning whole array but the program will be probably fail...")
    return array
[docs]
def get_column(array, latitude, longitude):
    """Return the vertical column of ``array`` at (``latitude``, ``longitude``).

    Args:
        array: 3-dimensional array indexed as [altitude, latitude, longitude].
        latitude: Keyword or index understood by ``get_latitude_index``.
        longitude: Keyword or index understood by ``get_longitude_index``.

    Returns:
        ``array[:, i_lat, i_lon]``. If ``array`` is not 3-dimensional
        (including scalars and any object without an ``ndim`` attribute), it
        is returned unchanged.
    """
    # `getattr` also covers ints, lists, None, ... which have no `ndim`.
    if getattr(array, "ndim", None) != 3:
        return array
    return array[:, get_latitude_index(latitude, array.shape[1]),
                 get_longitude_index(longitude, array.shape[2])]
[docs]
def convert_log(array, units: Optional[Literal['log', 'ln']] = None):
    """Bring :attr:`array` back from logarithmic space to linear space.

    Args:
        array: Values expressed in log space.
        units: ``'log'`` for base-10 values, ``'ln'`` for natural-log values.

    Returns:
        ``10 ** array`` for ``'log'``, ``exp(array)`` for ``'ln'``. For any
        other ``units`` a warning is emitted and :attr:`array` is returned
        as is.
    """
    if units == 'log':
        converted = np.power(10, array)
    elif units == 'ln':
        converted = np.exp(array)
    else:
        warnings.warn(f'Unable to determine the kind of log. `units` should be set to `log` or `ln`, provided: `{units}`.')
        converted = array
    return converted
[docs]
def mol_key(mol_dict, mol, mol_type="vap", data=""):
    """Look up the key associated with ``mol`` in ``mol_dict``.

    The lookup name is ``mol + data``. When ``mol_dict`` is ``None`` or does
    not hold that name, a default key ``<mol lowercased>_<mol_type><data>``
    is built instead.
    """
    lookup = mol + data
    if mol_dict is None or lookup not in mol_dict.keys():
        return "".join([mol.lower(), "_", mol_type, data])
    return mol_dict[lookup]
[docs]
def aerosols_array_iterator(dictionary):
    """Iterate over the array-valued entries of an aerosols dictionary.

    Yields ``(element, key)`` pairs for every nested value that is a NumPy
    array. For example, with
    :code:`{'H2O':{'mmr':np.array([1, 2]), 'reff':1e-5}}`
    only ``('H2O', 'mmr')`` is yielded: ``mmr`` is an array, ``reff`` is a
    float.
    """
    for name, properties in dictionary.items():
        yield from (
            (name, prop)
            for prop, content in properties.items()
            if isinstance(content, np.ndarray)
        )
[docs]
def arrays_to_zeros(dictionary, shape):
    """Copy ``dictionary`` and replace each nested array by zeros of ``shape``.

    Both the outer dictionary and each nested dictionary are shallow-copied,
    so the input is not modified. Non-array values are carried over
    unchanged. Used for aerosols, to initialize the arrays before
    interpolating.
    """
    zeroed = dictionary.copy()
    for name, properties in dictionary.items():
        blank = properties.copy()
        blank.update({
            prop: np.zeros(shape)
            for prop, content in properties.items()
            if isinstance(content, np.ndarray)
        })
        zeroed[name] = blank
    return zeroed
[docs]
def init_array(obj, size):
    """Give back ``obj`` when it is a float, otherwise a NaN-filled array of ``size``."""
    return obj if isinstance(obj, float) else np.full(size, np.nan)
[docs]
def get(obj, i):
    """Return ``obj[i]``, or ``obj`` itself when it cannot be indexed by ``i``.

    This lets a scalar (e.g. a float) be used wherever an indexable
    container is expected.

    Args:
        obj: Scalar or indexable container.
        i: Index or key.

    Returns:
        ``obj[i]`` when the lookup succeeds; ``obj`` when it raises
        ``TypeError`` (not subscriptable, bad index type) or ``LookupError``
        (index or key not found).
    """
    try:
        return obj[i]
    except (TypeError, LookupError):
        return obj
[docs]
def update_dict(d, u):
    """Recursively update the nested dictionary ``d`` with ``u``.

    Args:
        d: Dictionary to update; it is modified in place.
        u: Dictionary of updates. Values that are dicts are merged into the
            matching entry of ``d``; all other values overwrite it.

    Returns:
        ``d``, after the update.
    """
    # Function-scope import: keeps this block self-contained.
    from collections.abc import MutableMapping

    for k, v in u.items():
        if isinstance(v, dict):
            current = d.get(k)
            # A missing or non-mapping entry (e.g. an int) cannot be merged
            # into: start from a fresh dict so the nested update replaces it.
            if not isinstance(current, MutableMapping):
                current = {}
            d[k] = update_dict(current, v)
        else:
            d[k] = v
    return d
[docs]
def merge_attrs(obj, other):
    """Merge attributes from ``other`` into ``obj`` if they do not exist or are None.

    Args:
        obj: Object to complete; its ``__dict__`` is updated in place.
        other: Source of values: either a dict (its keys are used as
            attribute names) or an object (the names in its ``__dict__`` are
            read with ``getattr``).

    Returns:
        ``obj``. The merge is best-effort: if it fails (e.g. ``other`` is
        ``None``), ``obj`` is returned unchanged.
    """
    def is_missing(name):
        return not hasattr(obj, name) or getattr(obj, name) is None

    try:
        if isinstance(other, dict):
            additions = {k: other[k] for k in other if is_missing(k)}
        else:
            additions = {k: getattr(other, k) for k in other.__dict__ if is_missing(k)}
        obj.__dict__.update(additions)
    except Exception:
        # Maybe `other` is not well-constructed/does not exist: leave `obj`
        # as it is. KeyboardInterrupt/SystemExit are not swallowed.
        pass
    return obj
[docs]
def spectral_chunks(k_database, n):
    """Split the spectral range of ``k_database`` into ``n`` ranges.

    The chunk length is ``len(k_database.wns) / n``; the bounds of each range
    are read from ``k_database.wnedges``. The lower bound of the first range
    is set to -1. A range whose bounds fall outside ``wnedges`` is skipped.

    Returns:
        List of ``[lower, upper]`` pairs.
    """
    step = len(k_database.wns) / n
    ranges = []
    for index in range(n):
        try:
            edges = k_database.wnedges
            lower_at = int((index * step) - 1)
            upper_at = int(min(len(edges) - 1, (index + 1) * step))
            bounds = [edges[lower_at], edges[upper_at]]
        except IndexError:
            continue  # Outside of wns range now
        if index == 0:
            # The first lookup used index -1 (the last edge); override it.
            bounds[0] = -1
        ranges.append(bounds)
    return ranges
[docs]
def get_chunk(i, n, size):
    """Compute the bounds of the i-th chunk out of ``n`` chunks dividing ``size``.

    Args:
        i: Index of the chunk.
        n: Number of chunks.
        size: Either a number (1-dimensional case) or a sequence of length 2
            (2-dimensional case).

    Returns:
        ``[start, stop]`` in the 1-dimensional case; a pair of ``[start,
        stop]`` lists (first dimension, second dimension) in the
        2-dimensional case.
    """
    two_dimensional = hasattr(size, "__len__") and len(size) == 2
    if not two_dimensional:
        step = size / n
        return [int(i * step), min(int((i + 1) * step), size)]

    dim0, dim1 = size[0], size[1]
    total = dim0 * dim1
    step = total / n
    # Bounds of the chunk in the flattened (dim0 * dim1) index space.
    start = int(i * step)
    stop = int(min(int((i + 1) * step), total))

    along_first = [start % dim0, stop % dim0]  # chunk in 1st dimension
    along_second = [int(start / dim0), int(stop / dim0)]  # chunk in 2nd dimension
    if stop == total:
        # Last chunk: pin the end bounds instead of the wrapped-around values.
        along_first[1] = dim0
        along_second[1] = dim1 - 1
    return [along_first, along_second]
[docs]
def get_chunk_size(chunk, chunk_size, total_size):
    """Length of chunk number ``chunk`` when cutting ``total_size`` by ``chunk_size``.

    The chunk spans ``chunk * chunk_size`` up to ``(chunk + 1) * chunk_size``,
    with the upper bound capped at ``total_size`` so the last chunk may be
    shorter.
    """
    first = chunk * chunk_size
    past_last = (chunk + 1) * chunk_size
    if total_size < past_last:
        past_last = total_size  # last chunk may be shorter
    return past_last - first
[docs]
def get_wls(array, wls, array_wls):
    """Select the entries of :attr:`array` along its last axis for :attr:`wls`.

    :attr:`array_wls` holds the wavelengths of the last axis of
    :attr:`array`. It is reversed, searched with ``searchsorted`` and the
    negated insertion positions are used as indices on the last axis.

    NOTE(review): presumably :attr:`array_wls` is sorted in decreasing order,
    so that its reverse is increasing as ``searchsorted`` requires -- confirm
    with callers. Under that assumption, position ``j`` maps to index ``-j``,
    i.e. the entry just before the insertion point, and ``j == 0`` maps to
    index 0 -- TODO confirm this is the intended selection.
    """
    reversed_wls = np.asarray(array_wls[::-1])
    positions = reversed_wls.searchsorted(wls)
    return array[..., -positions]
[docs]
def make_array(lis):
    """Build a padded float array from a list of lists of unequal lengths.

    Args:
        lis: List of rows. Rows may have different lengths; their elements
            are either scalars or arrays sharing the shape of ``lis[0][0]``.

    Returns:
        Array of shape ``(len(lis), max_len) + element_shape`` where missing
        data is filled with -1. An empty ``lis`` gives an array of shape
        ``(0, 0)``.
    """
    n = len(lis)
    if n == 0:
        return np.full((0, 0), -1.)

    lengths = [len(row) for row in lis]
    max_len = max(lengths)
    try:
        # Array-valued elements add their own dimensions to the result.
        full_shape = (n, max_len) + lis[0][0].shape
    except (AttributeError, LookupError, TypeError):
        # Scalar elements (no usable `.shape`) or an empty first row.
        full_shape = (n, max_len)

    arr = np.full(full_shape, -1.)
    for i, (row, length) in enumerate(zip(lis, lengths)):
        if length:  # nothing to copy for an empty row, keep the -1 padding
            arr[i, :length] = row
    return arr
[docs]
def timer(f):
    """Decorator reporting how long each call to ``f`` took.

    After ``f`` returns, the elapsed time is passed to the ``debug`` method
    of the first positional argument of the call, so the decorated callable
    is expected to receive an object exposing ``debug`` as first argument.
    The return value of ``f`` is passed through.
    """
    @wraps(f)
    def timed(*args, **kwargs):
        started = datetime.datetime.now()
        outcome = f(*args, **kwargs)
        elapsed = (datetime.datetime.now() - started).total_seconds()
        args[0].debug("%s run in %.2fs" % (f.__name__, elapsed))
        return outcome

    return timed
[docs]
def query_yes_no(question, default="yes"):
    """Ask a yes/no question via ``input()`` and return the answer.

    Args:
        question: String that is presented to the user.
        default: Presumed answer if the user just hits <Enter>, or if no
            input is available (end of file). It must be "yes" (the
            default), "no" or None (meaning an answer is required of the
            user).

    Returns:
        True for "yes" or False for "no".

    Raises:
        ValueError: If ``default`` is not "yes", "no" or None.
        EOFError: If the input is exhausted while ``default`` is None, since
            there is then no answer to fall back on.
    """
    valid = {"yes": True, "y": True, "ye": True, "no": False, "n": False}
    if default is None:
        prompt = " [y/n] "
    elif default == "yes":
        prompt = " [Y/n] "
    elif default == "no":
        prompt = " [y/N] "
    else:
        raise ValueError("invalid default answer: '%s'" % default)

    while True:
        sys.stdout.write(question + prompt)
        try:
            # Surrounding whitespace is not significant in the answer.
            choice = input().strip().lower()
        except EOFError:
            if default is None:
                # An answer is required but none can be read any more.
                raise
            return valid[default]
        if default is not None and choice == "":
            return valid[default]
        if choice in valid:
            return valid[choice]
        sys.stdout.write("Please respond with 'yes' or 'no' " "(or 'y' or 'n').\n")