"""
Leaky-Integrator dynamic
------------------------
The evolution of the membrane voltage :math:`v_i(t)` of neuron :math:`i` is described by the following rule:
.. math::
\\tau_i \\frac{\\mathrm{d}v_i}{\\mathrm{d}t} = -v_i(t) + J \\sum_j w_{ij} f(v_j(t)) + J_u \\sum_k w_{ik} u_k(t) + J_n \\xi(t)
where :math:`f(.)` is some activation function and :math:`\\xi(t)` is the white noise imposed on the neuron. Note that time constants are in general neuron-specific, which renders the network heterogeneous.
"""
import os
osjoin = os.path.join # an alias for convenience
# Debugging helper; not referenced in the code visible here.
from pdb import set_trace
import warnings
# Turn every RuntimeWarning into an exception. The warnings filter is global,
# so this affects the whole process, not just this module. NumPy reports
# floating-point problems such as overflow in np.exp as RuntimeWarning, so
# those now abort instead of silently producing inf/nan.
warnings.simplefilter('error', RuntimeWarning)
import time
# strftime-style format for wall-clock times (HH:MM:SS); not referenced in the
# code visible here.
TIME_FMT = "%H:%M:%S"
import numpy as np
from scipy.integrate import solve_ivp
# from multiprocessing import get_context#, Pool, current_process, Manager
from joblib import Parallel, delayed
# try:
# import dask
# from dask.distributed import Client
# except ModuleNotFoundError:
# pass
from hetero import utils
from hetero.dynamics.base import Dynamic
[docs]
class LI(Dynamic):
    """
    The object for integrating networks of Leaky-Integrator (LI) neurons.

    Parameters
    ----------
    sim_path : str
        Simulation directory; each network's state is saved below it as
        ``<sim_path>/<net_id>/<state_name>.npy``.
    namespace : dict
        Parameters shared by all networks. They are handed to the integrator
        as keyword arguments.
    dyn_params : dict, optional
        Additional parameters merged into ``self.namespace``.
    """

    def __init__(self, sim_path, namespace, dyn_params=None):
        super().__init__(sim_path=sim_path, namespace=namespace, name='li', ext='npy')
        # ``None`` default instead of ``{}``: avoids a shared mutable default.
        if dyn_params is not None:
            self.namespace.update(dyn_params)

    def run(self, solver, state_name, stype=float, ncpu=-1, **kwargs):
        """Runs the LI dynamic for every network that is not computed yet.

        Parameters
        ----------
        solver : str
            Integrator's name. ``'euler'`` selects the forward Euler solver;
            any other value selects the Runge-Kutta (``solve_ivp``) solver.
        state_name : str
            State file name (the ``.npy`` extension is appended on save).
        stype : str or type, optional
            The data type in which the state is stored on the disk, by default float.
        ncpu : int, optional
            Number of CPU cores to use. The default of -1 means use all;
            1 runs serially without joblib.
        **kwargs
            Extra entries merged into the namespace, e.g. additional
            activation-function arguments such as ``x0`` or ``thr``.

        Returns
        -------
        int or None
            1 if every network was already computed (nothing is run),
            otherwise None.
        """
        # checking remaining nets
        net_id_list, taus_list = self.scan_for_undone(state_name)
        if len(taus_list) == 0:
            return 1  # exit if all is already computed

        # Replace the raw input / noise arrays with lookups by integer time
        # index. Values that are already callable are left alone. This keeps a
        # repeated ``run`` call from wrapping the wrapper, whose ``[:, t]``
        # indexing would then fail.
        u_ff = self.namespace['u']
        if not callable(u_ff):
            def ut(t):
                return u_ff[:, t]
            self.namespace['u'] = ut

        xi = self.namespace['xi']
        if not callable(xi):
            def xit(t):
                if xi is None:
                    # no noise: N is read lazily so it may come from kwargs
                    return np.zeros(self.namespace['N'])
                return xi[:, t]
            self.namespace['xi'] = xit

        self.namespace.update(**kwargs)  # to add additional af arguments
        integrator = forward_euler if solver == 'euler' else runge_kutta
        print(f'Integration is performed with {solver}.')

        # Each result is saved as soon as it is computed to free up RAM, so the
        # wrappers integrate *and* save (hence "integrate_n_save").
        if ncpu == 1:
            print('No parallelization is being used. Run serially.')
            integrate_n_save = integrate_n_save_serial
        else:
            print('joblib is being used.')
            integrate_n_save = integrate_n_save_joblib
        integrate_n_save(
            integrator=integrator, ncpu=ncpu,
            namespace=self.namespace,
            sim_path=self.sim_path,
            taus=taus_list, net_ids=net_id_list,
            state_name=state_name, stype=stype)
[docs]
def LI_grad(t, y,
            taus=None,
            A_adj=None,
            wu=None, w=None,
            Ju=1, J=1,
            Jn=1,
            af='possig',
            ut=None, xit=None,
            dt=1.,
            delay=0,
            **kwargs):
    """Right-hand side ``dy/dt`` of the LI dynamics with noise.

    Evaluates
    ``(-y + J * W @ f(y) + Ju * wu @ u(t - delay) + Jn * xi(t) / sqrt(dt)) / taus``
    where ``W`` is ``A_adj`` with its stored values replaced by ``w`` and
    ``f`` is ``activation_func(y, af, **kwargs)``.

    Parameters
    ----------
    t : float
        Current time.
    y : ndarray
        Current membrane voltages.
    taus : float or ndarray
        Time constants; the derivative is divided by them elementwise.
    A_adj : sparse matrix
        Connectivity pattern. Presumably a ``scipy.sparse`` matrix, since its
        private ``_with_data`` method is used (TODO confirm).
    wu, w : ndarray
        Input weights and recurrent weights (the data placed into ``A_adj``).
    Ju, J, Jn : float
        Gains of the external input, the recurrent input and the noise.
    af : str
        Activation function name, see ``activation_func``.
    ut, xit : ndarray or callable
        External input and noise, given either directly or as callables of
        the integer time index.
    dt : float
        Time step; used to turn ``t`` into an index and to scale the noise.
    delay : float
        Input delay; the input is read at index ``floor(max(0, t - delay) / dt)``.
    **kwargs
        Forwarded to ``activation_func`` (e.g. ``x0``, ``thr``).

    Returns
    -------
    ndarray
        ``dy/dt``, with the same shape as ``y``.
    """
    fs = activation_func(y, af, **kwargs)
    u_now = ut(int(max(0, t - delay) // dt)) if callable(ut) else ut
    xi_now = xit(int(t // dt)) if callable(xit) else xit
    # ``-y`` is a fresh array, so the in-place updates below never touch ``y``.
    dy = -y
    dy += J * A_adj._with_data(w) @ fs
    dy += Ju * wu @ u_now
    # A forward Euler step multiplies this derivative by dt. Dividing the noise
    # by sqrt(dt) therefore gives the Euler-Maruyama increment Jn*xi*sqrt(dt).
    dy += Jn * xi_now / np.sqrt(dt)
    dy /= taus
    return dy
[docs]
def integrate_n_save_joblib(integrator, namespace,
                            taus, net_ids,
                            sim_path, state_name,
                            stype, ncpu):
    """Integrate and save several networks in parallel with joblib.

    One ``_ins`` job is created per ``(tau, net_id)`` pair and dispatched to
    ``ncpu`` joblib workers. Every job writes its own result to disk, so
    nothing is returned.
    """
    jobs = (
        delayed(_ins)(integrator=integrator, namespace=namespace,
                      taus=tau, net_id=net_id,
                      sim_path=sim_path, state_name=state_name,
                      stype=stype)
        for tau, net_id in zip(taus, net_ids)
    )
    with Parallel(n_jobs=ncpu, return_as="generator") as parallel:
        # Exhaust the generator so that every job runs, and any worker
        # exception is raised here.
        for _ in parallel(jobs):
            pass
[docs]
def integrate_n_save_serial(integrator, namespace,
                            taus, net_ids,
                            sim_path, state_name,
                            stype, ncpu):
    """Integrate and save networks one after another in the current process.

    ``ncpu`` is accepted only so the signature matches
    ``integrate_n_save_joblib``; it is not used here.
    """
    common = dict(integrator=integrator, namespace=namespace,
                  sim_path=sim_path, state_name=state_name, stype=stype)
    for net_id, tau in zip(net_ids, taus):
        _ins(taus=tau, net_id=net_id, **common)
def _ins(integrator, namespace, taus, net_id, sim_path, state_name, stype):
    """Integrate one network, save its state and print the solver's message.

    The state is written to ``<sim_path>/<net_id>/<state_name>.npy`` through
    ``utils.asave`` with storage type ``stype``.
    """
    state, message = integrator(taus=taus, **namespace)
    target = osjoin(sim_path, net_id, f'{state_name}.npy')
    utils.asave(target, state, stype)
    print(f'{net_id}: {message}')
[docs]
def forward_euler(**arg_dict):
    """Implementation of the forward Euler(-Maruyama) solver.

    Reads ``dt``, ``n_timesteps``, ``y0``, ``u``, ``xi`` and (optionally)
    ``verbose`` from ``arg_dict``. The whole ``arg_dict`` is also forwarded
    to ``LI_grad``, which passes what it does not use on to the activation
    function.

    Returns
    -------
    y : ndarray, shape (len(y0), n_timesteps)
        State at times ``t_i = i * dt``.
    msg : str
        Status message.
    """
    dt = arg_dict['dt']
    n_timesteps = arg_dict['n_timesteps']
    verbose = arg_dict.get('verbose', False)
    y0 = np.asarray(arg_dict['y0'])
    # An integer initial state would truncate every Euler update to an
    # integer, so non-floating inputs are promoted to float.
    dtype = y0.dtype if np.issubdtype(y0.dtype, np.inexact) else float
    y = np.empty((len(y0), n_timesteps), dtype=dtype)
    if n_timesteps > 0:
        y[:, 0] = y0
    # Progress is printed roughly every 20 %. max(1, .) avoids a modulo by
    # zero when n_timesteps < 5.
    report_every = max(1, n_timesteps // 5)
    ut, xit = arg_dict['u'], arg_dict['xi']
    for i in range(1, n_timesteps):
        t = i * dt
        y[:, i] = y[:, i-1] + dt * LI_grad(t, y[:, i-1], ut=ut, xit=xit, **arg_dict)
        if verbose and (i + 1) % report_every == 0:
            print(f'\t{(i+1)*100./n_timesteps}% of {n_timesteps} timesteps completed.')
    msg = 'Forward Euler computed succesfully.'
    return y, msg
[docs]
def runge_kutta(**arg_dict):
    """Implementation of the Runge-Kutta solver. This function is a wrapper
    around Scipy's ``solve_ivp`` method (``RK45``).

    The solution is reported on the same grid as ``forward_euler``:
    ``t_i = i * dt`` for ``i = 0 .. n_timesteps - 1``. The right-hand side is
    ``LI_grad`` called with the whole ``arg_dict``, exactly like in
    ``forward_euler``. Optional entries such as ``delay``, ``x0`` or ``thr``
    are therefore honoured by both solvers.

    If ``verbose`` is set, the time axis is integrated in (up to) 5
    consecutive segments and progress is printed after each one. Every
    segment starts at the time and state where the previous one ended.

    Returns
    -------
    y : ndarray, shape (len(y0), n_t)
        State at the evaluation times. ``n_t`` is smaller than
        ``n_timesteps`` only if the solver failed (see ``msg``).
    msg : str
        ``solve_ivp``'s message for the last integrated segment.
    """
    dt = arg_dict['dt']
    n_timesteps = arg_dict['n_timesteps']
    verbose = arg_dict.get('verbose', False)
    y0 = arg_dict['y0']
    # Same grid as forward_euler. A grid of linspace(0, n*dt, n) would have a
    # spacing of n*dt/(n-1) != dt and end at t = n*dt, i.e. at the
    # out-of-range input/noise index n.
    t = np.arange(n_timesteps) * dt
    ut, xit = arg_dict['u'], arg_dict['xi']

    def rhs(t_now, y):
        return LI_grad(t_now, y, ut=ut, xit=xit, **arg_dict)

    n_segments = 5 if verbose else 1
    first_step = dt if verbose else dt / 100
    # array_split yields empty pieces when n_timesteps < n_segments.
    segments = [seg for seg in np.array_split(t, n_segments) if seg.size > 0]

    ys = []
    t_current, y_current = 0., y0
    msg = 'No timesteps to integrate.'
    for i, seg in enumerate(segments):
        tic = time.time()
        span = seg[-1] - t_current
        sol = solve_ivp(rhs, t_span=[t_current, seg[-1]], y0=y_current,
                        method='RK45', t_eval=seg,
                        # solve_ivp rejects a first_step larger than the span
                        first_step=min(first_step, span) if span > 0 else None)
        ys.append(sol.y)
        msg = sol.message
        if verbose:
            print(f'{(i+1)*100./len(segments)}% of integration ({len(seg)}) completed in {time.time()-tic}s with the following message:')
            print(sol.message)
        if not sol.success:
            break
        # The next segment continues from exactly where this one stopped.
        t_current, y_current = seg[-1], sol.y[:, -1]

    if ys:
        y = np.concatenate(ys, axis=1)
    else:
        y = np.empty((len(y0), 0))
    return y, msg
[docs]
def activation_func(x, af, x0=1., thr=0, **kwargs):
    """
    Apply a specified activation function to input data with optional scaling and thresholding.

    Most functions act on the shifted and scaled input ``z = (x - thr) / x0``.
    The valid activation functions are:

    - 'lin': linear, ``x / x0`` (``thr`` is not applied)
    - 'relu': rectified linear unit, ``max(0, z)``
    - 'caplin': ``z`` clipped to [0, 1]
    - 'caprelu': ``z`` clipped to [0, 1] (same formula as 'caplin')
    - 'capquad': ``z**2`` clipped to [0, 1]
    - 'capcube': ``z**3`` clipped to [0, 1]
    - 'capexp': ``exp(min(z, 0))``, an exponential capped at 1
    - 'tanh': hyperbolic tangent of ``z``
    - 'symsig': odd sigmoid ``2 / (1 + exp(-z)) - 1``, centered at z=0 with range (-1, 1)
    - 'possig': logistic sigmoid ``1 / (1 + exp(-z))``, equal to 0.5 at z=0, range (0, 1)
    - 'possig_plus': same as 'possig' for ``x > thr`` and zero otherwise
    - 'bell': non-normalized Gaussian ``exp(-z**2)``
    - 'camel': ``exp(1 - z**2) * (x / x0)**2``; for ``thr=0`` two identical
      bumps at (±1, 1)
    - 'switch': binary step, 1 where ``x - thr > x0`` and 0 otherwise

    Parameters
    ----------
    x : float or ndarray
        Input array or value to be transformed
    af: str
        Name of the activation function to apply (e.g., 'relu', 'tanh', 'possig', etc.)
    x0: float
        Scaling factor for the input (default: 1.0)
    thr: float
        Threshold (bias) value to shift the input (default: 0)
    kwargs: dict
        Ignored; lets callers pass a larger parameter namespace.

    Returns
    -------
    float or ndarray
        The result of applying the chosen activation function to the input

    Raises
    ------
    NotImplementedError
        If ``af`` is not one of the names listed above.
    """
    # The exp-based functions clip their argument to [-60, 60] so that np.exp
    # cannot overflow (exp(60) is about 1.1e26).
    if af == 'lin':  # linear; NOTE: unlike the others it does not subtract thr
        return x/x0
    elif af == 'relu':  # rectified linear
        return np.maximum(0, (x-thr)/x0)
    elif af == 'caprelu':  # capped relu
        return np.clip((x-thr)/x0, 0, 1)
    elif af == 'caplin':  # capped linear
        return np.clip((x-thr)/x0, 0, 1)
    elif af == 'capquad':  # capped quadratic
        return np.clip(((x-thr)/x0)**2, 0, 1)
    elif af == 'capcube':  # capped cubic
        return np.clip(((x-thr)/x0)**3, 0, 1)
    elif af == 'capexp':  # exponential capped at 1
        return np.exp(np.minimum((x-thr)/x0, 0))
    elif af == 'tanh':  # hyperbolic tangent
        return np.tanh((x-thr)/x0)
    elif af == 'possig':  # positive (logistic) sigmoid
        return 1./(1+np.exp(-np.clip((x-thr)/x0, -60, 60)))
    elif af == 'possig_plus':  # positive sigmoid above thr, zero elsewhere
        return 1./(1+np.exp(-np.clip((x-thr)/x0, -60, 60))) * (x > thr)
    elif af == 'symsig':  # symmetric sigmoid in (-1, 1)
        return 2./(1+np.exp(-np.clip((x-thr)/x0, -60, 60))) - 1
    elif af == 'bell':  # non-normalized Gaussian
        return np.exp(-np.clip((x-thr)/x0, -60, 60)**2)
    elif af == 'camel':
        # NOTE: the quadratic factor uses x/x0 (no thr shift), so the two
        # bumps are only symmetric about the Gaussian's centre when thr=0.
        return np.exp(-np.clip((x-thr)/x0, -60, 60)**2 + 1) * (x/x0)**2
    elif af == 'switch':  # binary step
        return np.where(x-thr > x0, 1, 0)
    # Unknown name: fail loudly instead of silently returning None.
    raise NotImplementedError('Activation function {} is not implemented.'.format(af))