# Source code for hetero.dynamics.li

"""
Leaky-Integrator dynamic
------------------------
The evolution of the membrane voltage :math:`v_i(t)` of neuron :math:`i` is described by the following rule: 

.. math::

    \\tau_i \\frac{\\mathrm{d}v_i}{\\mathrm{d}t} = -v_i(t) + J \\sum_j w_{ij} f(v_j(t)) + J_u \\sum_k w_{ik} u_k(t) + J_n \\xi(t)


where :math:`f(.)` is some activation function and :math:`\\xi(t)` is the white noise imposed on the neuron. Note that the time constants are in general neuron-specific, which renders the network heterogeneous.
"""



import os
osjoin = os.path.join # an alias for convenience
from pdb import set_trace
import warnings 
warnings.simplefilter('error', RuntimeWarning)

import time
TIME_FMT = "%H:%M:%S"

import numpy as np
from scipy.integrate import solve_ivp
# from multiprocessing import get_context#, Pool, current_process, Manager
from joblib import Parallel, delayed
# try:
#     import dask
#     from dask.distributed import Client
# except ModuleNotFoundError:
#     pass

from hetero import utils
from hetero.dynamics.base import Dynamic

[docs] class LI(Dynamic): """ The object for integrating neworks of Leaking-Integrator (LI) neurons. """ def __init__(self, sim_path, namespace, dyn_params={}): super().__init__(sim_path=sim_path, namespace=namespace, name='li', ext='npy') self.namespace.update(dyn_params)
[docs] def run(self, solver, state_name, stype=float, ncpu=-1, **kwargs): """Runs the LI dynamic. Parameters ---------- solver : str Integrator's name (euler or runge_kutta) state_name : str State file name stype : str or type, optional The data type in which the state is stoerd on the disk, by default float ncpu : int, optional Number of CPU cores to use. The default of -1 means use all. """ # checking remaining nets net_id_list, taus_list = self.scan_for_undone(state_name) if len(taus_list) == 0: return 1 # exit if all is already computed # updating the namespace with callable functions u_ff = self.namespace['u'] def ut(t): return u_ff[:, t] xi = self.namespace['xi'] def xit(t): if xi is None: return np.zeros(self.namespace['N']) else: return xi[:, t] self.namespace['u'] = ut self.namespace['xi']= xit self.namespace.update(**kwargs) # to add additional af arguments integrator = forward_euler if solver=='euler' else runge_kutta # apparently rk can be run in parallel as well # if solver!= 'euler': # ncpu = 1 # rk cannot be run in parallel print(f'Integration is performed with {solver}.') # we better immediately save the results to free up the ram. # so the wrappers are not just integratores, but integrate and # savers (hence the name i.&.s). if ncpu== 1: print('No parallelization is being used. Run serially.') integrate_n_save_serial( integrator=integrator, ncpu=ncpu, namespace=self.namespace, sim_path=self.sim_path, taus=taus_list, net_ids=net_id_list, state_name=state_name, stype=stype) else: print('joblib is being used.') integrate_n_save_joblib( integrator=integrator, ncpu=ncpu, namespace=self.namespace, sim_path=self.sim_path, taus=taus_list, net_ids=net_id_list, state_name=state_name, stype=stype)
def LI_grad(t, y, taus=None, A_adj=None, wu=None, w=None,
            Ju=1, J=1, Jn=1, af='possig', ut=None, xit=None,
            dt=1., delay=0, **kwargs):
    """Implementation of the LI dynamics with noise.

    Returns the (scaled) time derivative of the membrane voltages::

        dy/dt = (-y + J * W f(y) + Ju * Wu u(t-delay) + Jn * xi(t)/sqrt(dt)) / taus

    Parameters
    ----------
    t : float
        Current time.
    y : ndarray
        Current membrane voltages.
    taus : ndarray or float
        Neuron-specific membrane time constants.
    A_adj : scipy sparse matrix
        Adjacency structure; its data is replaced by ``w`` via ``_with_data``.
    wu : ndarray
        Input (feed-forward) weight matrix.
    w : ndarray
        Recurrent weights matching ``A_adj``'s sparsity pattern.
    Ju, J, Jn : float, optional
        Gains of input, recurrent, and noise terms (default 1).
    af : str, optional
        Activation function name, passed to ``activation_func``.
    ut, xit : callable or array-like
        Input and noise; callables are indexed by integer time step.
    dt : float, optional
        Time step used both for step indexing and noise scaling.
    delay : int or float, optional
        Input delay in time units.
    kwargs : dict
        Extra arguments forwarded to ``activation_func``.
    """
    _fs = activation_func(y, af, **kwargs)
    # inputs may be provided directly or as step-indexed callables
    _ut = ut if not callable(ut) else ut(int(max(0, t - delay)//dt))
    _xit = xit if not callable(xit) else xit(int(t//dt))
    # NOTE: the original initialized dy with zeros_like(y) and immediately
    # overwrote it; the dead store is removed here.
    dy = -y
    dy += J * A_adj._with_data(w) @ _fs  # recurrent drive
    dy += Ju * wu @ _ut                  # feed-forward drive
    # dy will later be multiplied by dt, so for a correct Euler-Maruyama
    # step the noise must be scaled by 1/sqrt(dt)
    dy += Jn * _xit / np.sqrt(dt)
    dy /= taus
    return dy
def integrate_n_save_joblib(integrator, namespace, taus, net_ids, sim_path,
                            state_name, stype, ncpu):
    """Integrate all remaining networks in parallel with joblib.

    Each (tau, net_id) pair is dispatched to ``_ins``, which integrates
    one network and writes its state to disk. Results stream back through
    a generator so memory is released as soon as each job finishes.
    """
    jobs = (
        delayed(_ins)(integrator=integrator,
                      namespace=namespace,
                      taus=tau,
                      net_id=net_id,
                      sim_path=sim_path,
                      state_name=state_name,
                      stype=stype)
        for tau, net_id in zip(taus, net_ids)
    )
    with Parallel(n_jobs=ncpu, return_as="generator") as parallel:
        # drain the generator so every job actually runs to completion
        list(parallel(jobs))
def integrate_n_save_serial(integrator, namespace, taus, net_ids, sim_path,
                            state_name, stype, ncpu):
    """Integrate all remaining networks one after another.

    Serial counterpart of ``integrate_n_save_joblib``: each network is
    integrated and saved via ``_ins`` before the next one starts. The
    ``ncpu`` argument is accepted only for interface symmetry.
    """
    for tau, net_id in zip(taus, net_ids):
        _ins(integrator=integrator,
             namespace=namespace,
             taus=tau,
             net_id=net_id,
             sim_path=sim_path,
             state_name=state_name,
             stype=stype)
def _ins(integrator, namespace, taus, net_id, sim_path, state_name, stype):
    """Integrate one network, save its state to disk, and report.

    A thin wrapper around ``integrator`` ("i.&.s" = integrate & save):
    the returned trajectory is stored as ``<sim_path>/<net_id>/<state_name>.npy``
    and the solver's status message is printed.
    """
    y, msg = integrator(taus=taus, **namespace)
    target = osjoin(sim_path, net_id, f'{state_name}.npy')
    utils.asave(target, y, stype)
    print(f'{net_id}: {msg}')
def forward_euler(**arg_dict):
    """Implementation of the forward Euler solver.

    Reads ``dt``, ``n_timesteps``, ``y0``, ``u``, ``xi``, and ``verbose``
    from ``arg_dict``; every remaining entry is forwarded to ``LI_grad``.
    Together with the 1/sqrt(dt) noise scaling inside ``LI_grad`` this is
    an Euler-Maruyama scheme.

    Returns
    -------
    (ndarray, str)
        The (N, n_timesteps) trajectory and a status message.
    """
    dt = arg_dict['dt']
    n_timesteps = arg_dict['n_timesteps']
    y0 = arg_dict['y0']
    y = np.empty((len(y0), n_timesteps), dtype=y0.dtype)
    y[:, 0] = y0
    # BUG FIX: report progress about five times per run, but guard the
    # period against n_timesteps < 5, which previously raised
    # ZeroDivisionError in the modulus when verbose was on.
    report_every = max(1, n_timesteps // 5)
    for i in range(1, n_timesteps):
        t = i*dt
        # 'u' and 'xi' also travel inside **arg_dict; they are absorbed by
        # LI_grad's **kwargs, while ut/xit are the ones actually used
        y[:, i] = y[:, i-1] + dt * LI_grad(t, y[:, i-1],
                                           ut=arg_dict['u'],
                                           xit=arg_dict['xi'],
                                           **arg_dict)
        if arg_dict['verbose'] and (i+1) % report_every == 0:
            print(f'\t{(i+1)*100./n_timesteps}% of {n_timesteps} timesteps completed.')
    # (typo "succesfully" fixed in the status message)
    msg = 'Forward Euler computed successfully.'
    return y, msg
def runge_kutta(**arg_dict):
    """Implementation of the Runge-Kutta solver.

    This function is a wrapper around Scipy's ``solve_ivp`` method.
    Reads ``dt``, ``n_timesteps``, ``y0``, and ``verbose`` from
    ``arg_dict``; the remaining entries are handed to ``LI_grad``
    positionally (order must match ``LI_grad``'s signature).
    Returns the (N, n_timesteps) trajectory and the solver's message.
    """
    dt = arg_dict['dt']
    n_timesteps = arg_dict['n_timesteps']
    # output sampling grid; the adaptive solver picks its own internal
    # steps and evaluates the solution on these times
    t = np.linspace(0, n_timesteps*dt, n_timesteps)
    # d = arg_dict['delay'] # check if that is integer
    y0 = arg_dict['y0']
    if not arg_dict['verbose']:
        # quiet mode: one solve over the full time span
        sol = solve_ivp(LI_grad, y0=y0, method='RK45',
                        t_span = [0, t[-1]],
                        t_eval = t,
                        args= [
                            # mandatory arguments (don't change the order)
                            arg_dict['taus'],
                            arg_dict['A_adj'],
                            arg_dict['wu'],
                            arg_dict['w'],
                            arg_dict['Ju'],
                            arg_dict['J'],
                            arg_dict['Jn'],
                            arg_dict['af'],
                            arg_dict['u'],   # -> LI_grad's `ut`
                            arg_dict['xi'],  # -> LI_grad's `xit`
                            arg_dict['dt'],
                        ],
                        first_step = dt/100,
                        # rtol = 1e-6,
                        # atol=1e-2
                        )
        y = sol.y
        msg = sol.message
    else:
        # verbose mode: split the span into 5 chunks so progress can be
        # printed; each chunk restarts from the previous chunk's end state
        chunks = np.array_split(t, 5)
        ys = []
        y_current = y0
        for i, chunk in enumerate(chunks):
            t_start = chunk[0]
            t_end = chunk[-1]
            tic = time.time()
            sol = solve_ivp(LI_grad, y0=y_current, method='RK45',
                            t_span = [t_start, t_end],
                            t_eval = chunk,
                            args= [
                                # mandatory arguments (don't change the order)
                                arg_dict['taus'],
                                arg_dict['A_adj'],
                                arg_dict['wu'],
                                arg_dict['w'],
                                arg_dict['Ju'],
                                arg_dict['J'],
                                arg_dict['Jn'],
                                arg_dict['af'],
                                arg_dict['u'],
                                arg_dict['xi'],
                                arg_dict['dt'],
                            ],
                            first_step = dt,
                            # rtol = 1e-3,
                            # atol=1e-2
                            )
            ys.append(sol.y)
            y_current = sol.y[:, -1]
            print(f'{(i+1)*100./len(chunks)}% of integration ({len(chunk)}) completed in {time.time()-tic}s with the following message:')
            print(sol.message)
        # stitch the chunk solutions back into one trajectory
        y = np.concatenate(ys, axis=1)
        msg = sol.message
        # NOTE(review): debug print left in place on purpose
        print(f'---------------- {y.shape, sol.y.shape, len(ys)}')
    return y, msg
def activation_func(x, af, x0=1., thr=0, **kwargs):
    """
    Apply a specified activation function to input data with optional
    scaling and thresholding.

    The valid activation functions are:

    - 'lin': linear
    - 'relu': rectified linear unit
    - 'caplin': same as 'lin' but capped between 0 and 1
    - 'caprelu': same as 'relu' but capped between 0 and 1
      (numerically identical to 'caplin' in this implementation)
    - 'capquad': quadratic function capped between 0 and 1
    - 'capcube': cubic function capped between 0 and 1
    - 'capexp': exponential function capped between 0 and 1
    - 'tanh': tangent hyperbolic
    - 'symsig': an odd sigmoidal function centered at x=0 with range
      between -1 and 1.
    - 'possig': sigmoidal function centered at y=0.5 with range between
      0 and 1.
    - 'possig_plus': same as 'possig' but only for values larger than
      threshold and zero otherwise.
    - 'bell': non-normalized Gaussian activation function
    - 'camel': a gaussian multiplied by a pure quadratic function,
      resulting in two identical bumps at (±1, 1).
    - 'switch': a binary activation: 1 where (x - thr) > x0, 0 otherwise.

    Parameters
    ----------
    x : float or ndarray
        Input array or value to be transformed
    af : str
        Name of the activation function to apply (e.g., 'relu', 'tanh',
        'possig', etc.)
    x0 : float
        Scaling factor for the input (default: 1.0)
    thr : float
        Threshold (bias) value to shift the input (default: 0)
    kwargs : dict
        Additional keyword arguments for specific activation functions

    Returns
    -------
    float or ndarray
        The result of applying the chosen activation function to the input

    Raises
    ------
    NotImplementedError
        If ``af`` does not name a known activation function.
    """
    if af == 'lin':  # linear
        return x/x0
    elif af == 'relu':  # rectified linear
        return np.maximum(0, (x-thr)/x0)
    elif af == 'caprelu':  # capped relu
        return np.clip((x-thr)/x0, 0, 1)
    elif af == 'caplin':  # capped linear
        return np.clip((x-thr)/x0, 0, 1)
    elif af == 'capquad':  # capped quadratic
        return np.clip(((x-thr)/x0)**2, 0, 1)
    elif af == 'capcube':  # capped cubic
        return np.clip(((x-thr)/x0)**3, 0, 1)
    elif af == 'capexp':  # capped exponential
        return np.exp(np.minimum((x-thr)/x0, 0))
    elif af == 'possig':  # positive sigmoid
        # clipping the argument at ±60 avoids exp overflow
        return 1./(1+np.exp(-np.clip((x-thr)/x0, -60, 60)))
    elif af == 'possig_plus':  # positive sigmoid above thr, zero elsewhere
        return 1./(1+np.exp(-np.clip((x-thr)/x0, -60, 60))) * (x > thr)
    elif af == 'symsig':  # symmetric sigmoid; clipped at ±60 against overflow
        return 2./(1+np.exp(-np.clip((x-thr)/x0, -60, 60))) - 1
    elif af == 'tanh':  # hyperbolic tangent
        return np.tanh((x-thr)/x0)
    elif af == 'bell':  # non-normalized Gaussian; clipped at ±60 against overflow
        return np.exp(-np.clip((x-thr)/x0, -60, 60)**2)
    elif af == 'camel':  # Gaussian times quadratic: identical bumps at (±1, 1)
        return np.exp(-np.clip((x-thr)/x0, -60, 60)**2 + 1) * (x/x0)**2
    elif af == 'switch':  # binary step against x0
        return np.where(x-thr > x0, 1, 0)
    else:
        # BUG FIX: the exception was previously constructed but never
        # raised, so an unknown name silently returned None.
        raise NotImplementedError(
            'Activation function {} is not implemented.'.format(af))