import xsimlab as xs
import attr
from attr import fields_dict
from collections import OrderedDict, defaultdict, Counter
from functools import wraps
import inspect
import numpy as np
from .variables import XSOVarType
from .backendcomps import FirstInit, SecondInit, ThirdInit, FourthInit, FifthInit
def _create_variables_dict(process_cls):
    """Collect the XSO variables declared on a component class.

    Only attributes whose metadata contains a ``"var_type"`` entry are kept;
    every other ``attr.Attribute`` is skipped. The iteration order of
    ``attr.fields_dict`` is preserved in the returned ``OrderedDict``.
    """
    xso_fields = OrderedDict()
    for name, attribute in fields_dict(process_cls).items():
        if "var_type" in attribute.metadata:
            xso_fields[name] = attribute
    return xso_fields
def _convert_2_xsimlabvar(var, intent='in',
                          var_dims=None, value_store=False, groups=None,
                          description_label='', attrs=True):
    """Convert an XSO variable into an xarray-simlab variable.

    The description, dimensions and attrs are read from the metadata of
    ``var`` and forwarded, together with the remaining arguments, to
    ``xs.variable``.

    Parameters
    ----------
    var : attr.Attribute
        XSO variable defined in an object decorated with xso.component().
    intent : str ('in' or 'out')
        Passed along to ``xs.variable``; defines the variable as receiving
        input from another component or being initialized within this one.
    var_dims : tuple, list or str, optional
        Dimensions of the created xsimlab variable. When None, the ``dims``
        entry of the variable metadata is used instead.
    value_store : bool
        When true, a 'time' dimension is appended to the dimensions (to each
        alternative when ``var_dims`` is a list), unless 'time' is already
        one of them.
    groups : str
        Passed along to ``xs.variable`` so that the variable output can be
        referenced via xarray-simlab group variables.
    description_label : str
        Prefix of the description; the ``description`` entry of the variable
        metadata is appended to it when present.
    attrs : bool
        If true, the ``attrs`` entry of the variable metadata is passed along
        to ``xs.variable``; otherwise an empty dict is passed.

    Returns
    -------
    xs.variable
        The xarray-simlab variable created from the XSO variable.

    Raises
    ------
    ValueError
        If ``value_store`` is true and the dimensions are neither a string,
        a tuple nor a list.
    """
    # get variable metadata
    var_description = var.metadata.get('description')
    if var_description:
        description_label = description_label + var_description

    if var_dims is None:
        var_dims = var.metadata.get('dims')

    # initialize dimensions, with time if value_store is true
    if value_store:
        if not var_dims:
            # no dimension supplied: 'time' is the only dimension
            var_dims = 'time'
        elif isinstance(var_dims, str):
            # Compare the whole name: a plain `'time' in var_dims` would be a
            # substring test and wrongly skip names such as 'lifetime'.
            if var_dims != 'time':
                var_dims = (var_dims, 'time')
        elif 'time' in var_dims:
            # 'time' is already one of the dimensions
            pass
        elif isinstance(var_dims, tuple):
            var_dims = var_dims + ('time',)
        elif isinstance(var_dims, list):
            # a list holds alternative dimension sets; extend each of them
            _dims = []
            for dim in var_dims:
                if isinstance(dim, str):
                    _dims.append((dim, 'time'))
                else:
                    _dims.append((*dim, 'time'))
            var_dims = _dims
        else:
            raise ValueError(
                f"Failed to parse dims argument for variable of type: "
                f"{var.metadata.get('var_type')} with description: {description_label} "
                f"with dimensions: {var_dims}")

    if var_dims is None:
        var_dims = ()

    var_attrs = var.metadata.get('attrs') if attrs else {}

    # return fully functional xarray-simlab variable
    return xs.variable(intent=intent, dims=var_dims, groups=groups,
                       description=description_label, attrs=var_attrs)
def _make_xso_variable(label, variable):
    """Build the xsimlab variables backing an XSO state variable.

    A foreign variable yields a single input named ``label`` (with the
    declared dims when ``list_input`` is set, scalar otherwise). A local
    variable yields ``<label>_label`` and ``<label>_init`` inputs plus a
    stored output named ``label``. Any other ``foreign`` value yields an
    empty mapping.

    Raises ValueError when ``list_input`` is set without ``dims``.
    """
    converted = defaultdict()
    is_foreign = variable.metadata.get('foreign')

    if is_foreign is True:
        if variable.metadata.get('list_input'):
            dims = variable.metadata.get('dims')
            if dims is None:
                raise ValueError("Variable with list_input=True requires passing dimension to dims keyword argument")
            converted[label] = _convert_2_xsimlabvar(
                var=variable, var_dims=dims, description_label='label list / ')
        else:
            converted[label] = _convert_2_xsimlabvar(
                var=variable, var_dims=(), description_label='label reference / ')
        return converted

    if is_foreign is False:
        converted[label + '_label'] = _convert_2_xsimlabvar(
            var=variable, var_dims=(), description_label='label / ')
        converted[label + '_init'] = _convert_2_xsimlabvar(
            var=variable, description_label='initial value / ')
        converted[label] = _convert_2_xsimlabvar(
            var=variable, intent='out', value_store=True,
            description_label='output of variable / ')

    return converted
def _make_xso_parameter(label, variable):
    """Build the xsimlab input variable backing an XSO parameter.

    Returns a mapping with a single entry, ``label``, holding the converted
    variable whose description is prefixed with 'parameter / '.
    """
    parameter_var = _convert_2_xsimlabvar(var=variable, description_label='parameter / ')
    return defaultdict(None, {label: parameter_var})
def _make_xso_forcing(label, variable):
    """Build the xsimlab variables backing an XSO forcing.

    A foreign forcing yields a single input named ``label``. A local forcing
    yields a ``<label>_label`` input and a stored ``<label>_value`` output.
    Any other ``foreign`` value yields an empty mapping.
    """
    converted = defaultdict()
    is_foreign = variable.metadata.get('foreign')

    if is_foreign is True:
        converted[label] = _convert_2_xsimlabvar(
            var=variable, description_label='label reference / ')
        return converted

    if is_foreign is False:
        converted[label + '_label'] = _convert_2_xsimlabvar(
            var=variable, description_label='label / ')
        converted[label + '_value'] = _convert_2_xsimlabvar(
            var=variable, intent='out', value_store=True,
            description_label='output of forcing value / ')

    return converted
def _make_xso_flux(label, variable):
    """Build the xsimlab variables backing an XSO flux.

    Always creates the stored ``<label>_value`` output. When the flux
    declares a ``group``, a scalar ``<label>_label`` output belonging to that
    group is added (without attrs). When it declares ``group_to_arg``, an
    ``xs.group`` entry of that name is added as well.
    """
    metadata = variable.metadata
    converted = defaultdict()

    converted[label + '_value'] = _convert_2_xsimlabvar(
        var=variable, intent='out', value_store=True,
        description_label='output of flux value / ')

    member_of = metadata.get('group')
    collects_from = metadata.get('group_to_arg')

    if member_of:
        converted[label + '_label'] = _convert_2_xsimlabvar(
            var=variable, intent='out', groups=member_of, var_dims=(),
            description_label='label reference with group / ',
            attrs=False)

    if collects_from:
        converted[collects_from] = xs.group(collects_from)

    return converted
def _make_xso_index(label, variable):
    """Build the xsimlab variables backing an XSO index.

    Creates an ``xs.index`` named ``label`` and an input variable named
    ``<label>_index``.

    Raises ValueError when ``dims`` is missing from the metadata or differs
    from ``label``.
    """
    metadata = variable.metadata

    # description of the index itself: fixed prefix plus the user description
    index_description = 'index / '
    user_description = metadata.get('description')
    if user_description:
        index_description = index_description + user_description

    dims = metadata.get('dims')
    if dims is None:
        raise ValueError("Argument dims is not supplied. Index variable requires passing the labels of dimension to 'dims' keyword.")
    if label != dims:
        raise ValueError("The variable name has to be the same as the dimension it labels. This is a requirement of xarray-simlab.")

    # fall back to an empty dict when no (or empty) attrs are declared
    index_attrs = metadata.get('attrs') or {}

    converted = defaultdict()
    converted[label] = xs.index(dims=dims, description=index_description, attrs=index_attrs)
    converted[label + '_index'] = _convert_2_xsimlabvar(var=variable, description_label='index / ')
    return converted
# Dispatch table: maps each XSO variable type to the factory function that
# builds the xarray-simlab variable(s) for an attribute of that type.
_make_xsimlab_vars = {
    XSOVarType.VARIABLE: _make_xso_variable,
    XSOVarType.FORCING: _make_xso_forcing,
    XSOVarType.PARAMETER: _make_xso_parameter,
    XSOVarType.FLUX: _make_xso_flux,
    XSOVarType.INDEX: _make_xso_index,
}
def _create_xsimlab_var_dict(cls_vars):
    """Convert all XSO variables of a component into xsimlab variables.

    Each attribute in ``cls_vars`` is passed to the factory registered for
    its ``var_type`` in ``_make_xsimlab_vars``; the resulting entries are
    merged into one mapping, later entries overriding earlier ones that
    share a key.
    """
    merged = defaultdict()
    for name, attribute in cls_vars.items():
        build = _make_xsimlab_vars[attribute.metadata.get('var_type')]
        merged.update(build(name, attribute))
    return merged
def _create_forcing_dict(cls, var_dict):
    """Map each forcing name to its setup function looked up on ``cls``.

    Only attributes of type ``XSOVarType.FORCING`` whose metadata names a
    ``setup_func`` are included; the function is fetched from ``cls`` by
    that name.
    """
    setup_funcs = defaultdict()
    for name, attribute in var_dict.items():
        if attribute.metadata.get('var_type') is not XSOVarType.FORCING:
            continue
        setup_name = attribute.metadata.get('setup_func')
        if setup_name is None:
            continue
        setup_funcs[name] = getattr(cls, setup_name)
    return setup_funcs
def _create_index_dict(cls, var_dict):
    """Return the subset of ``var_dict`` whose type is ``XSOVarType.INDEX``.

    The values are the attributes themselves. ``cls`` is accepted for
    signature parity with ``_create_forcing_dict`` but is not used.
    """
    return defaultdict(
        None,
        ((name, attribute) for name, attribute in var_dict.items()
         if attribute.metadata.get('var_type') is XSOVarType.INDEX),
    )
def _initialize_process_vars(cls, vars_dict):
    """Register the component's variables and parameters with the model backend.

    For every XSO variable in ``vars_dict``:

    - a foreign variable contributes the label(s) stored in attribute
      ``<name>``;
    - a local variable is added through ``cls.core.add_variable`` using the
      ``<name>_label`` and ``<name>_init`` attributes, and the returned value
      replaces attribute ``<name>``;
    - when the variable declares ``flux`` metadata, every flux is linked to
      the variable through ``cls.core.add_flux``.

    For every XSO parameter, the value stored in attribute ``<name>`` is
    registered through ``cls.core.add_parameter`` under the label
    ``<component label>_<name>``.

    Raises
    ------
    ValueError
        If only one of ``flux`` / ``negative`` is a list, or if both are
        lists of different length.
    Exception
        If a parameter is declared with anything other than foreign=False.
    """
    process_label = cls.label
    for key, var in vars_dict.items():
        var_type = var.metadata.get('var_type')

        if var_type is XSOVarType.VARIABLE:
            foreign = var.metadata.get('foreign')
            if foreign is True:
                _label = getattr(cls, key)
            elif foreign is False:
                _init = getattr(cls, key + '_init')
                _label = getattr(cls, key + '_label')
                setattr(cls, key, cls.core.add_variable(label=_label, initial_value=_init))

            flux_label = var.metadata.get('flux')
            if not flux_label:
                continue
            flux_negative = var.metadata.get('negative')
            list_input = var.metadata.get('list_input')

            # Normalize the single-flux and multi-flux forms into a list of
            # (flux label, negative) pairs.
            labels_are_list = isinstance(flux_label, list)
            negatives_are_list = isinstance(flux_negative, list)
            if labels_are_list and negatives_are_list:
                if len(flux_label) != len(flux_negative):
                    # zip() would silently drop the surplus entries
                    raise ValueError(
                        f"Variable {_label} was assigned {len(flux_label)} flux labels {flux_label} "
                        f"but {len(flux_negative)} negative arguments {flux_negative}, "
                        f"both lists need to have the same length")
                flux_pairs = list(zip(flux_label, flux_negative))
            elif labels_are_list or negatives_are_list:
                raise ValueError(
                    f"Variable {_label} was assigned {flux_label} with negative arguments {flux_negative}, "
                    f"both need to be supplied as list")
            else:
                flux_pairs = [(flux_label, flux_negative)]

            for _flx_label, _flx_negative in flux_pairs:
                if list_input:
                    # the labels are handed over as a list under a fixed variable label
                    cls.core.add_flux(process_label=cls.label, var_label="list_input", flux_label=_flx_label,
                                      negative=_flx_negative, list_input=_label)
                else:
                    cls.core.add_flux(process_label=cls.label, var_label=_label, flux_label=_flx_label,
                                      negative=_flx_negative)

        elif var_type is XSOVarType.PARAMETER:
            if var.metadata.get('foreign') is False:
                _par_value = getattr(cls, key)
                cls.core.add_parameter(label=process_label + '_' + key, value=_par_value)
            else:
                raise Exception("Sorry, currently XSO does not support foreign=True for parameters.")
def _initialize_fluxes(cls, vars_dict):
    """Register every XSO flux of the component with the model backend.

    The backend label of a flux is ``<component label>_<flux function name>``.
    The flux function is wrapped with ``cls.flux_decorator`` and handed to
    ``cls.core.register_flux``; the returned value is stored in attribute
    ``<name>_value``. For a flux that declares a ``group``, the backend
    label is additionally stored in ``<flux function name>_label``.
    """
    for name, attribute in vars_dict.items():
        metadata = attribute.metadata
        if metadata.get('var_type') is not XSOVarType.FLUX:
            continue

        func = metadata.get('flux_func')
        dims = metadata.get('dims')
        backend_label = cls.label + '_' + func.__name__

        if metadata.get('group'):
            setattr(cls, func.__name__ + '_label', backend_label)

        registered = cls.core.register_flux(
            label=backend_label, flux=cls.flux_decorator(func), dims=dims)
        setattr(cls, name + '_value', registered)
def _initialize_forcings(cls, forcing_dict):
    """Set up every XSO forcing of the component and register it with the backend.

    For each forcing, the setup function is called with ``cls`` as first
    argument and, for every other named argument in its signature, the
    attribute of the same name read from ``cls``. The function it returns is
    registered through ``cls.core.add_forcing`` under the label stored in
    ``<name>_label``; the result is stored in ``<name>_value``.
    """
    for name, setup_func in forcing_dict.items():
        label = getattr(cls, name + '_label')
        arg_names = inspect.getfullargspec(setup_func).args
        setup_kwargs = {arg: getattr(cls, arg) for arg in arg_names if arg != "self"}
        forcing_func = setup_func(cls, **setup_kwargs)
        registered = cls.core.add_forcing(label=label, forcing_func=forcing_func)
        setattr(cls, name + '_value', registered)
def _initialize_indexes(cls, index_dict):
    """Copy each ``<name>_index`` attribute of ``cls`` onto attribute ``<name>``.

    Only the keys of ``index_dict`` are used.
    """
    for name in index_dict:
        setattr(cls, name, getattr(cls, name + '_index'))
def _get_init_stage(vars_dict):
    """Determine the initialization stage of a component from its variables.

    The stage is the first match of:

    - "fifth"  if any variable declares ``group_to_arg``,
    - "fourth" if any variable is of type ``XSOVarType.FLUX``,
    - "third"  if any variable is of type ``XSOVarType.FORCING``,
    - "second" otherwise.

    :param vars_dict: dictionary of variables in component
    :return: init stage
    """
    all_metadata = [attribute.metadata for attribute in vars_dict.values()]

    if any(metadata.get('group_to_arg') for metadata in all_metadata):
        return "fifth"

    declared_types = {metadata.get('var_type') for metadata in all_metadata}
    if XSOVarType.FLUX in declared_types:
        return "fourth"
    if XSOVarType.FORCING in declared_types:
        return "third"
    return "second"
def _create_new_cls(cls, cls_dict, init_stage):
    """Create the backend class of an XSO component.

    The new class takes its name from ``cls``, its attributes from
    ``cls_dict`` and inherits from the parent class in xso.backendcomps
    matching ``init_stage`` ("first" ... "fifth").

    Raises Exception when ``init_stage`` is not one of the five stages.
    """
    stage_names = ("first", "second", "third", "fourth", "fifth")
    if init_stage not in stage_names:
        raise Exception("There was an error with the sorting of processes. The automatic sorting failed.")

    stage_parents = (FirstInit, SecondInit, ThirdInit, FourthInit, FifthInit)
    parent = stage_parents[stage_names.index(init_stage)]
    return type(cls.__name__, (parent,), cls_dict)
def component(cls=None):
    """A component decorator that adds everything needed to use the class
    as a XSO component. It is a wrapper for the xarray-simlab process decorator.

    A component represents a logical unit with the model, and usually implements:

    - A set of XSO variables, defined as class attributes, e.g. xso.variable,
      xso.parameter, xso.forcing or xso.flux.
    - One or more methods that can be functions defining a xso.flux or a
      xso.forcing.

    Parameters
    ----------
    cls : class, optional
        Allows applying this decorator either as @xso.component or
        @xso.component().

    Returns
    -------
    cls : class
        The decorated class, passed through ``xs.process``. When called
        without a class, the decorator function itself is returned.
    """

    def create_component(cls):
        """Construct the backend class and return it as an xarray-simlab process."""
        attr_cls = attr.attrs(cls, repr=False)
        vars_dict = _create_variables_dict(attr_cls)
        forcing_dict = _create_forcing_dict(cls, vars_dict)
        index_dict = _create_index_dict(cls, vars_dict)

        # implement a basic automatic process ordering
        init_stage_automated = _get_init_stage(vars_dict)
        new_cls = _create_new_cls(cls, _create_xsimlab_var_dict(vars_dict), init_stage_automated)

        def flux_decorator(self, func):
            """XSO flux function decorator to unpack arguments.

            The wrapped function receives the ``state``, ``parameters`` and
            ``forcings`` keyword arguments, looks up the labels listed in
            ``self.flux_input_args`` and calls ``func`` with the resolved
            values under the argument names the flux function declares.
            """
            @wraps(func)
            def unpack_args(**kwargs):
                state = kwargs.get('state')
                parameters = kwargs.get('parameters')
                forcings = kwargs.get('forcings')

                input_args = {}

                for v_dict in self.flux_input_args['vars']:
                    if isinstance(v_dict['label'], (list, np.ndarray)):
                        input_args[v_dict['var']] = [state[label] for label in v_dict['label']]
                    else:
                        input_args[v_dict['var']] = state[v_dict['label']]

                for v_dict in self.flux_input_args['list_input_vars']:
                    # axis=None flattens all states into a single 1-D array
                    input_args[v_dict['var']] = np.concatenate([state[label] for label in v_dict['label']],
                                                               axis=None)

                for v_dict in self.flux_input_args['group_args']:
                    states = [state[label] for label in v_dict['label']]
                    # unpack list to array, for easier handling of single group arg
                    input_args[v_dict['var']] = states[0] if len(states) == 1 else states

                for p_dict in self.flux_input_args['pars']:
                    input_args[p_dict['var']] = parameters[p_dict['label']]

                for f_dict in self.flux_input_args['forcs']:
                    input_args[f_dict['var']] = forcings[f_dict['label']]

                return func(self, **input_args)

            return unpack_args

        def initialize(self):
            """Defines xarray-simlab process `initialize` method
            that is executed at model runtime.
            """
            super(new_cls, self).initialize()
            _initialize_process_vars(self, vars_dict)
            # flux input args must exist before the fluxes are registered,
            # because the wrapped flux functions read them when called
            self.flux_input_args = _create_flux_inputargs_dict(self, vars_dict)
            _initialize_fluxes(self, vars_dict)
            _initialize_forcings(self, forcing_dict)
            _initialize_indexes(self, index_dict)

        setattr(new_cls, 'flux_decorator', flux_decorator)
        setattr(new_cls, 'initialize', initialize)

        # constructed class is passed through xsimlab process decorator:
        process_cls = xs.process(new_cls)

        # allow passing helper functions through to process class;
        # forcing setup functions and dunder attributes are not forwarded
        _forcing_input_functions = {value.__name__ for value in forcing_dict.values()}
        for attribute in dir(cls):
            if attribute.startswith("__") or attribute in _forcing_input_functions:
                continue
            if hasattr(cls, attribute) and callable(getattr(cls, attribute)):
                # Allow setting custom attr method, to be used in component
                setattr(process_cls, attribute, getattr(cls, attribute))

        return process_cls

    # Compare against None explicitly: a class can be falsy (e.g. when its
    # metaclass defines __len__ or __bool__) and must still be decorated.
    if cls is not None:
        return create_component(cls)
    return create_component