"""Source code for xso.model."""

from collections import defaultdict

import numpy as np


def return_dim_ndarray(value):
    """Return ``value`` as a numpy array with at least one dimension.

    * lists are converted with :func:`numpy.array`;
    * numpy arrays are returned as-is, except 0-d arrays, which are
      reshaped to shape ``(1,)``;
    * any other value (e.g. a Python or numpy scalar) is wrapped into a
      one-element array.

    Parameters
    ----------
    value : list, numpy.ndarray or scalar
        Value to convert.

    Returns
    -------
    numpy.ndarray
        Array with ``ndim >= 1``.
    """
    if isinstance(value, list):
        return np.array(value)
    if isinstance(value, np.ndarray):
        # A 0-d array (e.g. ``np.array(3.0)``) would otherwise break the
        # "at least 1d" guarantee; atleast_1d returns >=1-d arrays unchanged.
        return np.atleast_1d(value)
    return np.array([value])
class Model:
    """Base model class.

    Contains dictionaries of all model variables and components, as well as
    :meth:`model_function`, which sorts through all of these and computes each
    step at model runtime. It is instantiated once at model setup and filled
    with the specific variables defined in components and their supplied labels
    and values. The Model class is stored in the model backend and shared and
    written to by all components.
    """

    def __init__(self):
        """Initialize the containers that collect model components.

        The defaultdicts store model variables flexibly before they are
        initialized in the Xarray-simlab backend.
        """
        self.time = None
        self.variables = defaultdict()
        self.parameters = defaultdict()
        self.forcing_func = defaultdict()
        self.forcings = defaultdict()
        self.fluxes = defaultdict()
        self.flux_values = defaultdict()
        # label -> list of flux descriptor dicts. Each descriptor is unpacked
        # positionally as (flux label, negative flag, list input). The special
        # key "list_input" holds fluxes that are split over several variables.
        self.fluxes_per_var = defaultdict(list)
        self.var_dims = defaultdict()
        self.flux_dims = defaultdict()
        # Ordered label -> dims (None, int or shape tuple); defines the layout
        # of the flat state array consumed by ``unpack_flat_state``.
        self.full_model_dims = defaultdict()

    def __repr__(self):
        """Return a simple summary listing the labels of all model components."""
        return (
            f"Model contains: \n"
            f"Variables:{list(self.variables)} \n"
            f"Parameters:{list(self.parameters)} \n"
            f"Forcings:{list(self.forcings)} \n"
            f"Fluxes:{list(self.fluxes)} \n"
            f"Full Model Dimensions:{list(self.full_model_dims.items())} \n"
        )

    def unpack_flat_state(self, flat_state):
        """Convert the flat array of model values into a labeled dictionary.

        Called at the beginning of :meth:`model_function`. This allows for
        easier calculations and keeps the solver interface a plain flat array.
        Entries are read from ``flat_state`` in the order of
        ``self.full_model_dims``:

        * ``None`` dims -> a single element,
        * ``int`` dims -> a 1-d slice of that length,
        * any other dims (a shape tuple) -> a slice reshaped to that shape.

        Parameters
        ----------
        flat_state : numpy.ndarray
            1-d array holding the current values of all labels in
            ``self.full_model_dims``.

        Returns
        -------
        collections.defaultdict
            Mapping of label to the corresponding scalar or array slice.
        """
        state_dict = defaultdict()
        index = 0
        for key, dims in self.full_model_dims.items():
            if dims is None:
                state_dict[key] = flat_state[index]
                index += 1
            elif isinstance(dims, int):
                state_dict[key] = flat_state[index:index + dims]
                index += dims
            else:
                # int() because np.prod(()) returns the float 1.0, which is not
                # a valid slice bound.
                length = int(np.prod(dims))
                state_dict[key] = flat_state[index:index + length].reshape(dims)
                index += length
        return state_dict

    def _route_list_input_fluxes(self, flux_values):
        """Distribute fluxes declared with a list input onto their variables.

        A flux whose output length equals the number of listed variables gives
        one element to each variable. Otherwise the flux is cut into
        consecutive chunks sized by each variable's dims (``None`` counts as 1).

        Parameters
        ----------
        flux_values : mapping
            Flux label -> 1-d array of computed flux values.

        Returns
        -------
        collections.defaultdict(list)
            Variable label -> list of flux contributions, already negated for
            fluxes flagged as negative.

        Raises
        ------
        ValueError
            If the flux output dims match neither routing scheme.
        """
        routed = defaultdict(list)
        # .get() so that reading does not insert a "list_input" key into the
        # fluxes_per_var defaultdict as a side effect.
        for flux_var_dict in self.fluxes_per_var.get("list_input", ()):
            flux_label, negative, list_input = flux_var_dict.values()
            flux_val = flux_values[flux_label]
            # Scalar (None) dims count as length 1, same as for the variables.
            flux_dims = self.full_model_dims[flux_label] or 1
            list_var_dims = [self.full_model_dims[var] or 1 for var in list_input]

            if len(list_input) == flux_dims:
                # One flux element per listed variable.
                for var, flux in zip(list_input, flux_val):
                    routed[var].append(-flux if negative else flux)
            elif sum(list_var_dims) == flux_dims:
                # Consecutive chunks, each sized by the variable's dims.
                start = 0
                for var, size in zip(list_input, list_var_dims):
                    flux = flux_val[start:start + size]
                    start += size
                    routed[var].append(-flux if negative else flux)
            else:
                raise ValueError(
                    "ERROR: list input vars dims and flux output dims "
                    "do not match"
                )
        return routed

    def model_function(self, time=None, current_state=None, forcing=None):
        """Compute forcings and fluxes for the current model state.

        Called within the solve function of the Solver.

        Parameters
        ----------
        time : numpy array or None
            Can be passed explicitly for certain solvers, e.g. odeint solver
            passes array of model time. If given, every function in
            ``self.forcing_func`` is evaluated at ``time`` and any explicitly
            passed ``forcing`` is ignored.
        current_state : numpy array
            Large flat array containing all current values for the model,
            laid out as described by ``self.full_model_dims``.
        forcing : dict or None
            Can be passed explicitly for certain solvers, e.g. stepwise solver
            evaluates current time step value and passes that as dict. Falls
            back to ``self.forcings`` when neither ``time`` nor ``forcing``
            is given.

        Returns
        -------
        numpy.ndarray
            1-d array: the summed flux contributions of every variable (in
            ``self.variables`` order), followed by the values of every flux
            (in ``self.fluxes`` order).

        Raises
        ------
        ValueError
            If a list-input flux cannot be matched to its variables' dims.
        """
        state = self.unpack_flat_state(current_state)

        # Forcings for this time point.
        if time is not None:
            forcing = defaultdict()
            for key, func in self.forcing_func.items():
                forcing[key] = func(time)
        elif forcing is None:
            forcing = self.forcings

        # Evaluate fluxes in declaration order.
        flux_values = defaultdict()
        fluxes_out = []
        for flx_label, flux in self.fluxes.items():
            value = return_dim_ndarray(
                flux(state=state, parameters=self.parameters, forcings=forcing)
            )
            flux_values[flx_label] = value
            fluxes_out.append(value)
            # Make the freshly computed value visible to fluxes evaluated later.
            if flx_label in state:
                state[flx_label] = value

        list_input_fluxes = self._route_list_input_fluxes(flux_values)

        # Sum all flux contributions per variable.
        state_out = []
        for var_label in self.variables:
            dims = self.full_model_dims[var_label]
            var_fluxes = []
            flux_applied = False

            if var_label in self.fluxes_per_var:
                flux_applied = True
                for flux_var_dict in self.fluxes_per_var[var_label]:
                    flux_label, negative, _ = flux_var_dict.values()
                    # Scalar variables receive the total of a vector flux.
                    _flux = (flux_values[flux_label] if dims
                             else np.sum(flux_values[flux_label]))
                    var_fluxes.append(-_flux if negative else _flux)

            if var_label in list_input_fluxes:
                flux_applied = True
                for flux in list_input_fluxes[var_label]:
                    var_fluxes.append(flux if dims else np.sum(flux))

            if not flux_applied:
                # np.zeros handles both int and shape-tuple dims;
                # range(dims) failed for shape tuples.
                var_fluxes.append(np.zeros(dims) if dims else 0)

            state_out.append(np.sum(var_fluxes, axis=0))

        # Flatten again into a single array for the solver.
        pieces = [np.ravel(val) for val in state_out]
        pieces.extend(np.ravel(val) for val in fluxes_out)
        if not pieces:
            return np.array([])
        return np.concatenate(pieces)