import argparse
from copy import deepcopy
from hetero.config import CFG, VALID_ACTIVATIONS
def _str2bool(value):
    """
    Parse a command-line boolean flag.

    ``type=bool`` is broken for argparse (``bool('False')`` is ``True``,
    since any non-empty string is truthy), so boolean options use this
    converter instead.

    Accepts the usual spellings (true/false, t/f, yes/no, y/n, 1/0),
    case-insensitively; raises ``argparse.ArgumentTypeError`` otherwise.
    """
    if isinstance(value, bool):
        # Already a bool (e.g. the default pulled from CFG) — pass through.
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'Expected a boolean value, got {value!r}')


def get_parser():
    """
    Creates a parser that is able to capture all variables necessary to construct the parameter dictionaries.

    Defaults come from ``hetero.config.CFG``; every option can be
    overridden on the command line.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('--root', type=str, default=CFG['root'], help='Root directory to store results')
    parser.add_argument('--stype', type=str, default=CFG['stype'], help="Data type for saving the network state. Set to float64 for accurate forecasting benchmark, otherwise leave the default (float16).")
    parser.add_argument('--solver', type=str, default=CFG['solver'], help='Solver for integration')
    # heterogeneity
    parser.add_argument('--distro', type=str, default=CFG['distro'], help='Heterogeneity distribution (e.g., lognormal, gamma, uniform)')
    parser.add_argument('--n_means', type=int, default=CFG['n_means'], help='Number of mean timescales')
    parser.add_argument('--n_vars', type=int, default=CFG['n_vars'], help='Number of heterogeneity levels')
    parser.add_argument('--log_distance', type=float, default=CFG['log_distance'], help='Log distance between mean levels')
    # network
    parser.add_argument('--N', type=int, default=CFG['N'], help='Number of neurons in the network')
    parser.add_argument('--p', type=float, default=CFG['p'], help='Sparsity of the connectivity matrix')
    parser.add_argument('--f', type=float, default=CFG['f'], help='Fraction of excitatory neurons')
    parser.add_argument('--mue', type=float, default=CFG['mue'], help='Mean excitatory weight')
    parser.add_argument('--sig0', type=float, default=CFG['sig0'], help='Excitatory and inhibitory weight dispersion (std)')
    # NOTE: type=bool would treat any non-empty string (including "False") as
    # True; _str2bool parses the flag value correctly.
    parser.add_argument('--autapse', type=_str2bool, default=CFG['autapse'], help='Are autapses allowed?')
    parser.add_argument('--delay', type=int, default=CFG['delay'], help='Synaptic delay (in unit of dt)')
    # stim
    parser.add_argument('--stim', action='extend', nargs='+', default=[], help="The label of the inputs to the network. More than one label can be given to concatenate different timeseries.")
    # task
    parser.add_argument('--task', type=str, choices=list(CFG['task'].keys()), default='tayloramus', help="The task network is supposed to perform.")
    # dataset
    parser.add_argument('--n_trn', type=float, default=CFG['n_trn'], help='Length of the training set (in unit of the input base timescale)')
    parser.add_argument('--n_tst', type=float, default=CFG['n_tst'], help='Length of the test set (in unit of the input base timescale)')
    parser.add_argument('--n_trials', type=int, default=CFG['n_trials'], help='The number of times the training set is extended for performance uncertainty estimation.')
    parser.add_argument('--scale_by_size', action='store_true', default=CFG['scale_by_size'], help='Whether to scale the training set size by the network size. Default True.')
    # gain
    parser.add_argument('--Jn', type=float, default=CFG['Jn'], help='Gain of the white noise')
    parser.add_argument('--J', type=float, default=CFG['J'], help='Gain of the recurrent weights')
    parser.add_argument('--Ju', type=float, default=CFG['Ju'], help='Gain of the feedforward weights')
    # dynamics
    parser.add_argument('--dyn', type=str, choices=list(CFG['dyn'].keys()), default='LI', help='The name of network dynamic.')
    # readout (fit)
    parser.add_argument('--readout_method', type=str, default=CFG['readout_method'], help='Readout regression type (ridge or direct)')
    parser.add_argument('--readout_regularizer', type=float, default=CFG['readout_regularizer'], help='Ridge regularization value')
    parser.add_argument('--readout_has_bias', type=_str2bool, default=CFG['readout_has_bias'], help='Whether to fit a bias term? Default True.')
    parser.add_argument('--readout_activation', type=str, choices=list(CFG['readout_activation'].keys()), default='possig1', help='The activation function used for the readout.')
    # execution control
    parser.add_argument('--ncore', type=int, default=-1, help='Number of CPUs to use. -1 means use all.')
    parser.add_argument('--noint', action='store_true', default=False, help='Skip integration. Default False.')
    parser.add_argument('--nofit', action='store_true', default=False, help='Skip fitting. Default False.')
    parser.add_argument('--noviz', action='store_true', default=False, help='Skip visualization. Default False.')
    return parser
def create_param_dicts(args, verbose=True):
    """
    Creates proper dictionaries from args.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line arguments (see ``get_parser``).
    verbose : bool, optional
        If True (default), print a warning when an incompatible readout
        activation is silently replaced. The replacement itself happens
        regardless of this flag.

    Returns
    -------
    dict
        Nested parameter dictionary with keys ``heterog``, ``net``, ``set``,
        ``gain``, ``readout``, ``task``, ``dyn`` and ``stim``. CFG sub-dicts
        are deep-copied so later mutation of the result does not corrupt CFG.
    """
    hetero_params = {
        'distro': args.distro,
        'n_means': args.n_means,
        'n_vars': args.n_vars,
        'log_distance': args.log_distance,
    }
    net_params = {
        'N': args.N,
        'p': args.p,
        'f': args.f,
        'mue': args.mue,
        'sig0': args.sig0,
        'autapse': args.autapse,
        'topology': 'random',  # only random topology is supported here
        'delay': args.delay,
    }
    gain_params = {
        'Jn': args.Jn,
        'J': args.J,
        'Ju': args.Ju,
    }
    task_params = {
        args.task: deepcopy(CFG['task'][args.task]),
    }
    stim_params = {}
    for stim in args.stim:
        stim_params[stim] = deepcopy(CFG['stim'][stim])
    dataset_params = {
        'n_trn': args.n_trn,
        'n_tst': args.n_tst,
        'n_trials': args.n_trials,
        'scale_by_size': args.scale_by_size,
    }
    dyn_params = {
        args.dyn: deepcopy(CFG['dyn'][args.dyn]),
    }
    # ensure that the readout activation and the dynamic are compatible
    # LIF --> convolution
    # LI  --> activation function
    # NOTE: args.readout_activation is deliberately mutated in place so that
    # downstream consumers of args also see the corrected value.
    if args.readout_activation not in VALID_ACTIVATIONS[args.dyn]:
        if args.dyn == 'LI':
            if verbose:
                print(f'Warning: The readout of a rate neuron shall not be done via {args.readout_activation}. Switched back to possig1.')
            args.readout_activation = 'possig1'
        else:
            if verbose:
                print(f'Warning: The readout of a spiking neuron shall not be done via the {args.readout_activation} activation function. Switched back to d10.')
            args.readout_activation = 'd10'
    readout_params = {
        'readout_method': args.readout_method,
        'readout_regularizer': args.readout_regularizer,
        'readout_has_bias': args.readout_has_bias,
        'readout_activation': {
            args.readout_activation: deepcopy(CFG['readout_activation'][args.readout_activation])
        },
    }
    params = {
        'heterog': hetero_params,
        'net': net_params,
        'set': dataset_params,
        'gain': gain_params,
        'readout': readout_params,
        'task': task_params,
        'dyn': dyn_params,
        'stim': stim_params,
    }
    return params