"""
This module defines the linear readout used to decode the target functions from the networks' state. This module is based on scikit-learn's linear model module and can be extended to other (possibly non-linear) regression methods as needed.
"""
import os
osjoin = os.path.join
from pdb import set_trace
import numpy as np
from sklearn import linear_model
from sklearn.metrics import confusion_matrix, r2_score
from hetero.utils import timer
class Decoder(object):
    """
    A decoder class for fitting and predicting using various regression methods.

    Parameters
    ----------
    readout_params : dict
        Dictionary containing readout parameters:

        - 'readout_method': str, regression method ('ridge', 'lasso', 'linreg', 'logreg')
        - 'readout_has_bias': bool, whether to include bias term
        - 'readout_alpha': float, regularization parameter (for ridge and lasso);
          defaults to 1.0 (scikit-learn's default) if absent
    ncpu : int, optional
        Number of CPU cores to use, by default -1 (all cores)

    .. warning:: The readout method must be compatible with the problem at hand. No sanity
        checks are performed to assess whether that is the case or not.
    """
    def __init__(self, readout_params, ncpu=-1, **kwargs):
        self.method = readout_params['readout_method']
        self.has_bias = readout_params['readout_has_bias']
        # BUG FIX: alpha was previously read from 'readout_has_bias' (copy-paste
        # error), which silently set the regularization strength to the boolean
        # bias flag (True -> 1.0, False -> 0.0). Read the dedicated key instead,
        # falling back to scikit-learn's default of 1.0 so older parameter dicts
        # that lack the key still construct a valid decoder.
        self.alpha = readout_params.get('readout_alpha', 1.0)
        self.n_jobs = ncpu
        # Regression problems are scored with R^2; classification ('logreg')
        # switches to a confusion matrix below.
        self.score_fun = det_coeff
        self.coef_ = None
        # Standard ridge regression
        if self.method == 'ridge':
            self.model = linear_model.Ridge(alpha=self.alpha, fit_intercept=self.has_bias)
        # Standard lasso regression
        elif self.method == 'lasso':
            self.model = linear_model.Lasso(alpha=self.alpha, fit_intercept=self.has_bias)
        # Standard linear regression
        elif self.method == 'linreg':
            self.model = linear_model.LinearRegression(fit_intercept=self.has_bias,
                                                       n_jobs=self.n_jobs)
        # Standard logistic regression (classification)
        elif self.method == 'logreg':
            self.score_fun = conf_mat
            self.model = linear_model.LogisticRegression(fit_intercept=self.has_bias,
                                                         n_jobs=self.n_jobs)
        # Removed in favor of Scikit-learn that handles the intercept much nicer.
        ## direct least squares (scipy)
        # elif self.method == 'direct':
        #     self.model = lstsq
        else:
            raise ValueError(f'Fitting method {self.method} is invalid.')

    def fit(self, X, y):
        """Fit the model using features X and target y.

        The project convention stores data as (n_features, n_samples), while
        scikit-learn expects (n_samples, n_features) -- hence the transposes.
        """
        self.model.fit(X.T, y.T)

    def predict(self, X):
        """Predict using fitted model on features X (transposed back to project layout)."""
        return self.model.predict(X.T).T

    def score(self, yhat, y):
        """Compute score between predictions yhat and true values y."""
        return self.score_fun(yhat, y)

    def __get__(self):
        # NOTE(review): this is not a valid descriptor signature
        # (__get__(self, obj, objtype)); it looks intended as a plain accessor
        # for the underlying sklearn model. Kept unchanged for compatibility.
        return self.model

    @property
    def coeff(self):
        """Get the fitted model coefficients."""
        return self.model.coef_
def det_coeff(yhat, y, eps=1e-9):
    """Coefficient of determination (R-squared) between predictions and targets.

    For 2-D inputs, rows are treated as output dimensions and columns as
    samples, and an R-squared value is computed independently for each row.
    For 1-D inputs the computation is delegated to scikit-learn's
    ``r2_score``. Relevant for regression problems.

    Parameters
    ----------
    yhat : array-like
        Predicted values
    y : array-like
        True values
    eps : float, optional
        Small constant added to the total variance to avoid division by
        zero for constant targets, default 1e-9

    Returns
    -------
    float or array-like
        R-squared score(s). Returns a 1-D array (one entry per row) for
        2-D targets.
    """
    if y.ndim == 2:
        # Per-row R^2 = 1 - SS_res / SS_tot, with eps guarding SS_tot == 0.
        residual = ((y - yhat) ** 2).sum(axis=1)
        total = ((y - y.mean(axis=1, keepdims=True)) ** 2).sum(axis=1) + eps
        return 1 - residual / total
    return r2_score(y, yhat)
def conf_mat(yhat, y):
    """Confusion matrix between predicted and true labels.

    Relevant for classification problems; thin wrapper around
    scikit-learn's ``confusion_matrix`` with true labels on the rows.

    Parameters
    ----------
    yhat : array-like
        Predicted values
    y : array-like
        True values

    Returns
    -------
    2d array
        Confusion matrix
    """
    cm = confusion_matrix(y, yhat)
    return cm
# def fit(X, y, method='direct', n_jobs=-1, **kwargs):
# """Fit model parameters using specified method.
# Args:
# X (ndarray): Input features of shape (n_features, n_samples)
# y (ndarray): Target values of shape (n_targets, n_samples)
# method (str, optional): Fitting method. One of: 'direct', 'sk_ridge',
# 'sk_lasso', 'sk_linreg'. Defaults to 'direct'.
# n_jobs (int, optional): Number of jobs for parallel computation.
# Defaults to -1.
# **kwargs: Additional parameters passed to specific methods:
# - alpha: Regularization strength for ridge/lasso
# Returns:
# ndarray: Fitted model parameters
# Raises:
# ValueError: If method is not one of the supported options
# """
# if method == 'direct':
# # Direct least squares solution
# w, res, rank, s = lstsq(X.T, y.T)
# return w.T
# elif method == 'sk_ridge':
# # Ridge regression with cross-validation
# alpha = kwargs.get('alpha', 1e-6)
# reg = linear_model.Ridge(alpha=alpha, fit_intercept=False)
# reg.fit(X.T, y.T)
# return reg.coef_
# elif method == 'sk_lasso':
# # LASSO regression for sparsity
# alpha = kwargs.get('alpha', 1e-6)
# reg = linear_model.Lasso(alpha=alpha, fit_intercept=False)
# reg.fit(X.T, y.T)
# return reg.coef_
# elif method == 'sk_linreg':
# # Standard linear regression
# reg = linear_model.LinearRegression(fit_intercept=False, n_jobs=n_jobs)
# reg.fit(X.T, y.T)
# return reg.coef_
# elif method == 'sk_logreg':
# # Standard linear regression
# reg = linear_model.LogisticRegression(fit_intercept=False,
# max_iter=1000, n_jobs=n_jobs)
# reg.fit(X.T, y.T)
# return reg.coef_
# else:
# raise ValueError(f'Fitting method {method} is invalid.')