Fitting
This module defines the linear readout used to decode the target functions from the networks' state. It is built on scikit-learn's linear model module and can be extended to other (possibly non-linear) regression methods as needed.
- class hetero.fitting.Decoder(readout_params, ncpu=-1, **kwargs)
Bases: object
A decoder class for fitting and predicting with various regression methods.
- Parameters:
readout_params (dict) –
Dictionary containing readout parameters:
'readout_method': str, regression method ('ridge', 'lasso', 'linreg', 'logreg')
'readout_has_bias': bool, whether to include a bias term
'readout_has_bias': float, regularization parameter (for ridge and lasso)
ncpu (int, optional) – Number of CPU cores to use, by default -1 (all cores)
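Warning
The readout method must be compatible with the problem at hand (for example, 'logreg' is a classifier and suits discrete targets, while 'ridge', 'lasso' and 'linreg' suit continuous targets). No sanity checks are performed to verify this.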
- property coeff
Get the model coefficients.
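To show what a 'ridge' readout with 'readout_has_bias' enabled computes, here is a minimal NumPy sketch of closed-form ridge regression from network states to targets. It is not the Decoder's implementation, which delegates to scikit-learn. The function names `ridge_readout` and `predict` are hypothetical. This sketch assumes states are arranged as (time steps, units). When a bias is requested, the data are centered so that the intercept is not penalized, which is the same idea scikit-learn's `Ridge` uses with `fit_intercept=True`. In this sketch the weight matrix has shape (N, D).

```python
import numpy as np


def ridge_readout(states, targets, alpha=1.0, has_bias=True):
    """Hypothetical sketch: fit W (and bias b) so that targets ~ states @ W + b.

    states:  array of shape (T, N), one network state per time step (assumed layout)
    targets: array of shape (T,) or (T, D)
    """
    X = np.asarray(states, dtype=float)
    Y = np.asarray(targets, dtype=float)
    if Y.ndim == 1:
        Y = Y[:, None]  # treat a single target as one column

    if has_bias:
        # Center states and targets so the intercept is not regularized.
        x_mean = X.mean(axis=0)
        y_mean = Y.mean(axis=0)
        Xc, Yc = X - x_mean, Y - y_mean
    else:
        Xc, Yc = X, Y

    # Closed-form ridge solution: W = (X^T X + alpha * I)^{-1} X^T Y
    n_units = Xc.shape[1]
    W = np.linalg.solve(Xc.T @ Xc + alpha * np.eye(n_units), Xc.T @ Yc)

    # Recover the intercept from the means (zero when no bias is fitted).
    b = y_mean - x_mean @ W if has_bias else np.zeros(Y.shape[1])
    return W, b


def predict(states, W, b):
    """Apply the linear readout to a batch of states."""
    return np.asarray(states, dtype=float) @ W + b
```

Larger `alpha` shrinks the readout weights toward zero. This trades some fit on the training states for robustness to noisy or highly correlated units.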
- hetero.fitting.det_coeff(yhat, y, eps=1e-09)
Calculate the coefficient of determination (R-squared) score.
Handles both single and multi-dimensional targets, with built-in outlier handling. For multi-dimensional data, R-squared is computed separately for each dimension. Relevant for regression problems.
- Parameters:
yhat (array-like) – Predicted values
y (array-like) – True values
eps (float, optional) – Small constant to avoid division by zero, default 1e-9
- Returns:
R-squared score(s). For multi-dimensional targets, an array with one score per dimension.
- Return type:
float or array-like
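The per-dimension score can be sketched as R² = 1 - SS_res / (SS_tot + eps). Here SS_res = Σ(y - ŷ)² and SS_tot = Σ(y - ȳ)², with both sums taken over samples. This is a minimal NumPy sketch, not det_coeff itself, and it rests on a few assumptions. The name `r2_per_dimension` is hypothetical. Samples are assumed to lie along axis 0, and eps is assumed to go in the denominator. The outlier handling mentioned above is not described in detail, so it is deliberately omitted.

```python
import numpy as np


def r2_per_dimension(yhat, y, eps=1e-9):
    """Hypothetical sketch of a per-dimension coefficient of determination.

    Assumes samples along axis 0; eps guards against a zero denominator.
    Does NOT reproduce det_coeff's outlier handling.
    """
    yhat = np.asarray(yhat, dtype=float)
    y = np.asarray(y, dtype=float)
    # Residual sum of squares and total sum of squares, summed over samples.
    ss_res = np.sum((y - yhat) ** 2, axis=0)
    ss_tot = np.sum((y - y.mean(axis=0)) ** 2, axis=0)
    # Scalar for 1-D targets, one score per column for 2-D targets.
    return 1.0 - ss_res / (ss_tot + eps)
```

Because of eps, a constant target that is predicted exactly scores 1.0 instead of producing a division by zero.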