Source code for hetero.fitting

"""
This module defines the linear readout used to decode the target functions
from the networks' state. It is based on scikit-learn's linear model module
and can be extended to other (possibly non-linear) regression methods as
needed.
"""

import os 
osjoin = os.path.join
from pdb import set_trace

import numpy as np
from sklearn import linear_model
from sklearn.metrics import confusion_matrix, r2_score

from hetero.utils import timer 
    
[docs]
class Decoder(object):
    """
    A decoder class for fitting and predicting using various regression
    methods.

    Parameters
    ----------
    readout_params : dict
        Dictionary containing readout parameters:

        - 'readout_method': str, regression method ('ridge', 'lasso',
          'linreg', 'logreg')
        - 'readout_has_bias': bool, whether to include a bias term
        - 'readout_alpha': float, regularization strength (ridge and
          lasso only)
    ncpu : int, optional
        Number of CPU cores to use, by default -1 (all cores).

    .. warning::
        The readout method must be compatible with the problem at hand.
        No sanity checks are performed to verify that this is the case.
    """

    def __init__(self, readout_params, ncpu=-1, **kwargs):
        self.method = readout_params['readout_method']
        self.has_bias = readout_params['readout_has_bias']
        self.alpha = readout_params['readout_alpha']  # regularization strength
        self.n_jobs = ncpu
        self.score_fun = det_coeff
        self.coef_ = None

        # Standard ridge regression
        if self.method == 'ridge':
            self.model = linear_model.Ridge(alpha=self.alpha,
                                            fit_intercept=self.has_bias)
        # Standard lasso regression
        elif self.method == 'lasso':
            self.model = linear_model.Lasso(alpha=self.alpha,
                                            fit_intercept=self.has_bias)
        # Standard linear regression
        elif self.method == 'linreg':
            self.model = linear_model.LinearRegression(
                fit_intercept=self.has_bias, n_jobs=self.n_jobs)
        # Standard logistic regression
        elif self.method == 'logreg':
            self.score_fun = conf_mat
            self.model = linear_model.LogisticRegression(
                fit_intercept=self.has_bias, n_jobs=self.n_jobs)
        # Removed in favor of scikit-learn, which handles the intercept
        # much more cleanly.
        ## direct least squares (scipy)
        # elif self.method == 'direct':
        #     self.model = lstsq
        else:
            raise ValueError(f'Fitting method {self.method} is invalid.')

[docs]
    def fit(self, X, y):
        """Fit the model using features X and target y.

        Both arrays use the (dimension, samples) layout and are transposed
        internally to scikit-learn's (samples, dimension) convention.
        """
        self.model.fit(X.T, y.T)

[docs]
    def predict(self, X):
        """Predict using the fitted model on features X."""
        return self.model.predict(X.T).T

[docs]
    def score(self, yhat, y):
        """Compute score between predictions yhat and true values y."""
        return self.score_fun(yhat, y)

    def __get__(self, obj, objtype=None):
        # Descriptor hook exposing the underlying scikit-learn model.
        return self.model

    @property
    def coeff(self):
        """Get the fitted model coefficients."""
        return self.model.coef_
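
# Example usage (an illustrative sketch, not part of the module): shapes
# follow the transposed convention of ``fit``/``predict`` above, i.e.
# X is (n_features, n_samples) and y is (n_targets, n_samples). The
# parameter values below are assumptions made for the demo.
#
# rng = np.random.default_rng(0)
# X = rng.standard_normal((50, 1000))            # (n_features, n_samples)
# y = rng.standard_normal((3, 50)) @ X           # (n_targets, n_samples)
# dec = Decoder({'readout_method': 'ridge',
#                'readout_has_bias': True,
#                'readout_alpha': 1e-3})
# dec.fit(X, y)
# yhat = dec.predict(X)
# print(dec.score(yhat, y))                      # per-target R^2 values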

[docs]
def det_coeff(yhat, y, eps=1e-9):
    """Calculate the coefficient of determination (R-squared) score.

    Handles both single- and multi-dimensional targets. For
    multi-dimensional data, R-squared is computed separately for each
    dimension (row), with a small ``eps`` guarding against zero-variance
    targets. Relevant for regression problems.

    Parameters
    ----------
    yhat : array-like
        Predicted values
    y : array-like
        True values
    eps : float, optional
        Small constant to avoid division by zero, default 1e-9

    Returns
    -------
    float or array-like
        R-squared score(s). Returns array for multi-dimensional targets.
    """
    if y.ndim == 2:
        v = np.sum((y - y.mean(axis=1, keepdims=True))**2, axis=1) + eps
        u = np.sum((y - yhat)**2, axis=1)
        return 1 - u/v
    else:
        return r2_score(y, yhat)
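
# Sanity-check sketch (illustrative values): for 2-D targets of shape
# (n_targets, n_samples), ``det_coeff`` returns one R-squared per target
# row and agrees with sklearn's ``r2_score`` applied row-wise, up to the
# ``eps`` guard in the denominator.
#
# y    = np.array([[1., 2., 3.], [2., 4., 6.]])
# yhat = np.array([[1., 2., 2.], [2., 4., 5.]])
# det_coeff(yhat, y)         # -> approximately array([0.5, 0.875])
# r2_score(y[0], yhat[0])    # -> 0.5, matching the first row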

[docs]
def conf_mat(yhat, y):
    """Compute confusion matrix between predicted and true labels.

    Relevant for classification problems.

    Parameters
    ----------
    yhat : array-like
        Predicted labels
    y : array-like
        True labels

    Returns
    -------
    2d array
        Confusion matrix (rows: true classes, columns: predicted classes)
    """
    return confusion_matrix(y, yhat)
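
# Illustrative call (labels are made up): as in sklearn's
# ``confusion_matrix``, rows index the true classes and columns the
# predicted classes.
#
# y    = np.array([0, 0, 1, 1])
# yhat = np.array([0, 1, 1, 1])
# conf_mat(yhat, y)          # -> array([[1, 1],
#                            #           [0, 2]])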

# Legacy functional interface, kept for reference; superseded by the
# ``Decoder`` class above.
#
# def fit(X, y, method='direct', n_jobs=-1, **kwargs):
#     """Fit model parameters using specified method.
#
#     Args:
#         X (ndarray): Input features of shape (n_features, n_samples)
#         y (ndarray): Target values of shape (n_targets, n_samples)
#         method (str, optional): Fitting method. One of: 'direct',
#             'sk_ridge', 'sk_lasso', 'sk_linreg', 'sk_logreg'.
#             Defaults to 'direct'.
#         n_jobs (int, optional): Number of jobs for parallel computation.
#             Defaults to -1.
#         **kwargs: Additional parameters passed to specific methods:
#             - alpha: Regularization strength for ridge/lasso
#
#     Returns:
#         ndarray: Fitted model parameters
#
#     Raises:
#         ValueError: If method is not one of the supported options
#     """
#     if method == 'direct':
#         # Direct least squares solution
#         w, res, rank, s = lstsq(X.T, y.T)
#         return w.T
#     elif method == 'sk_ridge':
#         # Ridge regression
#         alpha = kwargs.get('alpha', 1e-6)
#         reg = linear_model.Ridge(alpha=alpha, fit_intercept=False)
#         reg.fit(X.T, y.T)
#         return reg.coef_
#     elif method == 'sk_lasso':
#         # LASSO regression for sparsity
#         alpha = kwargs.get('alpha', 1e-6)
#         reg = linear_model.Lasso(alpha=alpha, fit_intercept=False)
#         reg.fit(X.T, y.T)
#         return reg.coef_
#     elif method == 'sk_linreg':
#         # Standard linear regression
#         reg = linear_model.LinearRegression(fit_intercept=False,
#                                             n_jobs=n_jobs)
#         reg.fit(X.T, y.T)
#         return reg.coef_
#     elif method == 'sk_logreg':
#         # Standard logistic regression
#         reg = linear_model.LogisticRegression(fit_intercept=False,
#                                               max_iter=1000, n_jobs=n_jobs)
#         reg.fit(X.T, y.T)
#         return reg.coef_
#     else:
#         raise ValueError(f'Fitting method {method} is invalid.')