Fitting

This module defines the linear readout used to decode target functions from the network state. It builds on scikit-learn’s linear model module and can be extended to other (possibly non-linear) regression methods as needed.

class hetero.fitting.Decoder(readout_params, ncpu=-1, **kwargs)

Bases: object

A decoder class for fitting and predicting using various regression methods.

Parameters:
  • readout_params (dict) –

    Dictionary containing readout parameters:

    • ‘readout_method’: str, regression method (‘ridge’, ‘lasso’, ‘linreg’, or ‘logreg’)

    • ‘readout_has_bias’: bool, whether to include a bias (intercept) term

    • ‘readout_has_bias’: float, regularization parameter (for ‘ridge’ and ‘lasso’)

  • ncpu (int, optional) – Number of CPU cores to use, by default -1 (all cores)

Warning

The readout method must match the problem at hand (for example, ‘logreg’ for classification; ‘ridge’, ‘lasso’, or ‘linreg’ for regression). No sanity checks are performed to verify this.
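Because Decoder performs no such checks, a caller may want to guard against obvious mismatches before fitting. The sketch below is a hypothetical helper, not part of hetero.fitting; its heuristics (integer-valued targets suggest classification, binary targets fed to a regressor are suspicious) and the max_classes threshold are assumptions chosen for illustration.

```python
import numpy as np

# Hypothetical helper (not part of hetero.fitting): flags obvious mismatches
# between the chosen readout method and the target array before fitting.
CLASSIFIERS = {"logreg"}
REGRESSORS = {"ridge", "lasso", "linreg"}


def check_readout_compat(method, y, max_classes=50):
    """Return a list of warning strings; an empty list means no mismatch was detected."""
    y = np.asarray(y)
    warnings = []
    if method not in CLASSIFIERS | REGRESSORS:
        warnings.append(f"unknown readout_method {method!r}")
        return warnings
    # Heuristic: targets that are all whole numbers look like class labels.
    is_integer_like = bool(np.all(np.mod(y, 1) == 0))
    n_unique = np.unique(y).size
    if method in CLASSIFIERS and not is_integer_like:
        warnings.append("logreg expects discrete class labels, got non-integer targets")
    if method in CLASSIFIERS and n_unique > max_classes:
        warnings.append(f"logreg target has {n_unique} distinct values; is this a regression task?")
    if method in REGRESSORS and is_integer_like and n_unique <= 2:
        warnings.append(f"{method} fitted on binary labels; consider logreg")
    return warnings
```

For instance, check_readout_compat("ridge", [0, 1, 0, 1]) returns one warning, while check_readout_compat("logreg", [0, 1, 1, 0]) returns an empty list.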

fit(X, y)

Fit the model using features X and target y.

predict(X)

Predict targets for features X using the fitted model.

score(yhat, y)

Compute the score between predictions yhat and true values y.

property coeff

Get the model coefficients.
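The fit / predict / coeff interface can be illustrated with a minimal NumPy sketch of what a ‘ridge’ readout computes, assuming the standard closed-form ridge solution with an unpenalized intercept. This is not the library’s implementation (which builds on scikit-learn); the class name, alpha, has_bias, W_ and b_ are hypothetical, and samples are assumed to lie along axis 0.

```python
import numpy as np


class RidgeReadoutSketch:
    """Closed-form ridge readout, y ~ X @ W + b (hypothetical, illustrative only)."""

    def __init__(self, alpha=1.0, has_bias=True):
        self.alpha = alpha        # L2 penalty on W (regularization strength)
        self.has_bias = has_bias  # whether to fit an unpenalized intercept b

    def fit(self, X, y):
        X = np.asarray(X, dtype=float)
        y = np.asarray(y, dtype=float)
        if self.has_bias:
            # Centre features and targets so the intercept is not penalized.
            x_mean, y_mean = X.mean(axis=0), y.mean(axis=0)
        else:
            x_mean, y_mean = np.zeros(X.shape[1]), np.zeros(y.shape[1:])
        Xc, yc = X - x_mean, y - y_mean
        n_features = X.shape[1]
        # Solve the normal equations (Xc^T Xc + alpha I) W = Xc^T yc.
        self.W_ = np.linalg.solve(Xc.T @ Xc + self.alpha * np.eye(n_features), Xc.T @ yc)
        # Recover the intercept from the means.
        self.b_ = y_mean - x_mean @ self.W_
        return self

    def predict(self, X):
        return np.asarray(X, dtype=float) @ self.W_ + self.b_

    @property
    def coeff(self):
        return self.W_
```

With alpha close to zero and a full-rank X, this reduces to ordinary least squares (the ‘linreg’ case); larger alpha shrinks the coefficients toward zero.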

hetero.fitting.det_coeff(yhat, y, eps=1e-09)

Calculate the coefficient of determination (R-squared) score.

Handles both single- and multi-dimensional targets, with built-in outlier handling. For multi-dimensional targets, R-squared is computed separately for each dimension. Relevant for regression problems.

Parameters:
  • yhat (array-like) – Predicted values

  • y (array-like) – True values

  • eps (float, optional) – Small constant to avoid division by zero, default 1e-9

Returns:

R-squared score(s): a float for single-dimensional targets, or an array with one score per dimension for multi-dimensional targets.

Return type:

float or array-like
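The quantity reported can be sketched in NumPy as R² = 1 − SS_res / (SS_tot + eps), computed per output dimension. This is a hypothetical stand-in, not det_coeff itself: it assumes samples lie along axis 0, and it omits the outlier handling mentioned above, whose details are not documented here.

```python
import numpy as np


def r2_sketch(yhat, y, eps=1e-9):
    """Per-dimension coefficient of determination (hypothetical sketch)."""
    yhat = np.asarray(yhat, dtype=float)
    y = np.asarray(y, dtype=float)
    ss_res = np.sum((y - yhat) ** 2, axis=0)            # residual sum of squares
    ss_tot = np.sum((y - y.mean(axis=0)) ** 2, axis=0)  # total sum of squares
    r2 = 1.0 - ss_res / (ss_tot + eps)                  # eps guards constant targets
    # Scalar for 1-D targets, one score per column otherwise.
    return float(r2) if r2.ndim == 0 else r2
```

A perfect prediction gives 1, and predicting the target mean gives (almost exactly) 0; worse-than-mean predictions yield negative scores.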

hetero.fitting.conf_mat(yhat, y)

Compute the confusion matrix between predicted and true labels. Relevant for classification problems.

Parameters:
  • yhat (array-like) – Predicted labels

  • y (array-like) – True labels

Returns:

Confusion matrix

Return type:

2d array
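A minimal NumPy sketch of a confusion matrix with the same argument order as conf_mat. The function name is hypothetical, labels are assumed to be non-negative integers 0..K-1, and the convention that rows index true labels and columns index predictions (as in scikit-learn) is an assumption, not something the documentation above states.

```python
import numpy as np


def conf_mat_sketch(yhat, y):
    """C[i, j] counts samples with true label i predicted as label j (hypothetical sketch)."""
    yhat = np.asarray(yhat, dtype=int).ravel()
    y = np.asarray(y, dtype=int).ravel()
    n_classes = int(max(yhat.max(), y.max())) + 1
    C = np.zeros((n_classes, n_classes), dtype=int)
    # Unbuffered add, so repeated (true, predicted) pairs are all counted.
    np.add.at(C, (y, yhat), 1)
    return C
```

The diagonal holds correct predictions, so np.trace(C) / C.sum() gives the accuracy.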