ladder.scripts.workflows.InterpretableWorkflow#
- class ladder.scripts.workflows.InterpretableWorkflow(anndata, verbose=False, random_seed=None)#
Interpretable workflow for training with a linear decoder.
Inherits BaseWorkflow and adds functionality for running interpretable models with linear decoders.
- Parameters:
Attributes table#
Methods table#
| evaluate_reconstruction | Evaluates the quality of reconstructions with generative metrics. |
| evaluate_separability | Evaluates the separability of latent embeddings for conditions. |
| get_common_loadings | Writes non-conditional gene loadings to var. |
| get_conditional_loadings | Writes attribute-specific gene loadings to var. |
| load_model | Loads parameters for the attached model. |
| plot_loss | Simple plotter for loss functions. |
| prep_model | Prepares the model to be run. |
| run_model | Runs the model on the attached data object. |
| | Saves the attached model. |
| | Places the calculated cell embeddings from the trained model under the corresponding |
Attributes#
- InterpretableWorkflow.METRICS_REG = {'chamfer': 'Chamfer Discrepancy', 'corr': 'Profile Correlation', 'rmse': 'RMSE', 'swd': '2-Sliced Wasserstein'}#
- InterpretableWorkflow.OPT_CLASS1 = ['SCVI', 'SCANVI']#
- InterpretableWorkflow.OPT_CLASS2 = ['Patches']#
- InterpretableWorkflow.OPT_DEFAULTS = {'betas': (0.9, 0.999), 'eps': 0.01, 'gamma': 1, 'lr': 0.001, 'milestones': [10000000000.0]}#
- InterpretableWorkflow.OPT_LIST = ['optimizer', 'optim_args', 'gamma', 'milestones', 'lr', 'eps', 'betas']#
- InterpretableWorkflow.SEP_METRICS_REG = {'calc_asw': 'Average Silhouette Width', 'kmeans_ari': 'K-Means ARI', 'kmeans_nmi': 'K-Means NMI', 'knn_error': 'kNN Classifier Accuracy'}#
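As a rough illustration of how the optimizer constants above might fit together, the sketch below overlays user-supplied values on OPT_DEFAULTS and rejects keys that are not named in OPT_LIST. The helper `resolve_optim_args` is hypothetical and not part of the library; the merging logic actually used by the workflow may differ.

```python
# Values copied from the class attributes listed above.
OPT_DEFAULTS = {
    "betas": (0.9, 0.999),
    "eps": 0.01,
    "gamma": 1,
    "lr": 0.001,
    "milestones": [10000000000.0],
}
OPT_LIST = ["optimizer", "optim_args", "gamma", "milestones", "lr", "eps", "betas"]


def resolve_optim_args(optim_args=None):
    """Hypothetical helper: overlay user values on the defaults.

    Assumption: keys outside OPT_LIST are treated as errors here; the
    library itself may handle them differently.
    """
    user_args = dict(optim_args or {})
    unknown = sorted(set(user_args) - set(OPT_LIST))
    if unknown:
        raise KeyError(f"Unsupported optimizer arguments: {unknown}")
    # Later entries win, so user values replace the defaults.
    return {**OPT_DEFAULTS, **user_args}


resolved = resolve_optim_args({"lr": 0.01})
print(resolved["lr"], resolved["eps"])
```

If gamma and milestones parameterize a multi-step learning-rate schedule (an assumption), then a gamma of 1 with a single milestone at 1e10 would leave the learning rate unchanged throughout training.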
Methods#
- InterpretableWorkflow.evaluate_reconstruction(subset=None, cell_type=None, n_iter=5)#
Evaluates the quality of reconstructions with generative metrics.
- Parameters:
subset (str, optional) – Key from levels to subset cells for a specific condition before evaluating reconstruction.
cell_type (str, optional) – Requires cell_type_label_key to be defined as an attribute. Subsets cells to a single type before evaluating reconstruction.
n_iter (int, default: 5) – Number of times to repeat the generative process.
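To make the role of n_iter concrete, here is a minimal sketch that repeats a toy generative step several times and averages two of the metrics named in METRICS_REG (RMSE and profile correlation). Everything in it is an assumption for illustration: the `evaluate` and `noisy_copy` functions are hypothetical, the toy profile is made up, and the library may compute or aggregate its metrics differently.

```python
import math
import random


def rmse(a, b):
    """Root-mean-square error between two equal-length profiles."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)) / len(a))


def pearson(a, b):
    """Pearson correlation between two equal-length profiles."""
    mean_a, mean_b = sum(a) / len(a), sum(b) / len(b)
    cov = sum((x - mean_a) * (y - mean_b) for x, y in zip(a, b))
    var_a = sum((x - mean_a) ** 2 for x in a)
    var_b = sum((y - mean_b) ** 2 for y in b)
    return cov / math.sqrt(var_a * var_b)


def evaluate(observed, generate, n_iter=5, seed=0):
    """Average each metric over n_iter independently generated profiles."""
    rng = random.Random(seed)
    totals = {"rmse": 0.0, "corr": 0.0}
    for _ in range(n_iter):
        generated = generate(rng)
        totals["rmse"] += rmse(observed, generated)
        totals["corr"] += pearson(observed, generated)
    return {name: total / n_iter for name, total in totals.items()}


observed = [1.0, 2.0, 4.0, 8.0]  # toy mean expression profile (made up)


def noisy_copy(rng):
    """Hypothetical stand-in for a model's generative step."""
    return [value + rng.gauss(0.0, 0.1) for value in observed]


scores = evaluate(observed, noisy_copy, n_iter=5)
print(scores)
```

Averaging over several draws reduces the influence of any single stochastic sample on the reported score.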
- InterpretableWorkflow.evaluate_separability(factor=None)#
Evaluates the separability of latent embeddings for conditions.
- Parameters:
factor (str, optional) – Item listed in BaseWorkflow.factors. If not provided, the metrics are evaluated on the combinations of factors.
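The difference between evaluating a single factor and evaluating combinations of factors can be pictured as a choice of per-cell labels. The sketch below builds both kinds of labels from toy metadata; the function `combined_labels`, the factor names, and the " - " separator are all hypothetical and say nothing about how the library encodes combinations internally.

```python
def combined_labels(cells, factors):
    """Join each cell's values for the given factors into one label."""
    return [" - ".join(str(cell[name]) for name in factors) for cell in cells]


# Toy per-cell metadata with two made-up factors.
cells = [
    {"age": "young", "treatment": "control"},
    {"age": "old", "treatment": "treated"},
    {"age": "old", "treatment": "control"},
]

single = combined_labels(cells, ["age"])              # one factor
joint = combined_labels(cells, ["age", "treatment"])  # all factors combined
print(single)
print(joint)
```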
- InterpretableWorkflow.get_common_loadings()#
Writes non-conditional gene loadings to var.
Can be used with all models.
- InterpretableWorkflow.get_conditional_loadings()#
Writes attribute-specific gene loadings to var.
Only to be used with Patches, as the other models do not offer an attribute-specific way to learn coefficients.
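Because conditional loadings are only meaningful for Patches, calling code may want to guard that call on the model type. The sketch below shows one way to do so. `StandInWorkflow` is a recording stand-in so the snippet runs without the library installed; it is not the real InterpretableWorkflow, and `write_loadings` is a hypothetical helper.

```python
class StandInWorkflow:
    """Minimal stand-in that records calls; NOT the real workflow class."""

    def __init__(self):
        self.calls = []

    def get_common_loadings(self):
        self.calls.append("get_common_loadings")

    def get_conditional_loadings(self):
        self.calls.append("get_conditional_loadings")


def write_loadings(workflow, model_type):
    """Write common loadings always, conditional loadings only for Patches."""
    workflow.get_common_loadings()
    if model_type == "Patches":
        workflow.get_conditional_loadings()


scvi_like = StandInWorkflow()
write_loadings(scvi_like, "SCVI")

patches_like = StandInWorkflow()
write_loadings(patches_like, "Patches")

print(scvi_like.calls)
print(patches_like.calls)
```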
- InterpretableWorkflow.load_model(params_load_path)#
Loads parameters for the attached model. Requires prep_model() to be run first.
- Parameters:
params_load_path (str) – Path to the model parameters. Expects only the shared prefix, without the trailing “_torch.pth” or “_pyro.pth”.
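The prefix convention can be illustrated with a small helper that expands a shared prefix into the two file names mentioned above and rejects a path that already carries one of the suffixes. `checkpoint_files` and the example prefix are hypothetical; the library's own path handling is not shown in this page.

```python
# Suffixes taken from the parameter description above.
SUFFIXES = ("_torch.pth", "_pyro.pth")


def checkpoint_files(prefix):
    """Expand a shared prefix into the two expected parameter file names."""
    if prefix.endswith(SUFFIXES):
        raise ValueError("Pass only the shared prefix, without the suffix.")
    return [prefix + suffix for suffix in SUFFIXES]


files = checkpoint_files("checkpoints/patches_run")  # made-up prefix
print(files)
```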
- InterpretableWorkflow.plot_loss(save_loss_path=None)#
Simple plotter for loss functions.
- Parameters:
save_loss_path (str, optional) – If provided, saves the figure to the specified location. Requires the full file name with extension (e.g. fig.png).
- InterpretableWorkflow.prep_model(factors, batch_key=None, cell_type_label_key=None, minibatch_size=128, model_type='Patches', model_args=None, optim_args=None)#
Prepares the model to be run.
The choice of model implicitly decides the kind of condition encodings to use, so there is no need to have a separate data preparation.
- Parameters:
batch_key (str, optional) – Defines the workflow to be used. Affects model structure. Can later be accessed through the attribute of the same name.
cell_type_label_key (str, optional) – Optional cell type labels in obs, required if cell-type specific evaluation is desired.
minibatch_size (int, default: 128) – Size of the minibatch to be provided during training.
model_type (Literal["SCVI", "SCANVI", "Patches"], default: “Patches”) – Specifies the model attached to the current workflow.
model_args (dict) – Model arguments passed to the low-level model constructor. See models for details.
optim_args (dict) – Optimizer arguments passed to the low-level trainer. See training for details.
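A typical call order, as far as this page suggests, is prep_model() followed by run_model(). The sketch below passes only keyword names that appear in the documented signatures. `StandInWorkflow` records calls so the snippet runs without the library; it is not the real class. The factor names, the assumption that factors is a list of strings, and the save path are all made up for illustration.

```python
class StandInWorkflow:
    """Records method calls; NOT the real InterpretableWorkflow."""

    def __init__(self):
        self.calls = []

    def prep_model(self, factors, **kwargs):
        self.calls.append(("prep_model", list(factors), kwargs))

    def run_model(self, **kwargs):
        self.calls.append(("run_model", kwargs))


def prepare_and_train(workflow, factors, params_save_path=None):
    """Prepare a Patches model, then train it with the documented defaults."""
    workflow.prep_model(
        factors,
        batch_key=None,
        cell_type_label_key=None,
        minibatch_size=128,
        model_type="Patches",
    )
    workflow.run_model(max_epochs=1500, params_save_path=params_save_path)


workflow = StandInWorkflow()
prepare_and_train(workflow, ["age", "treatment"], params_save_path="checkpoints/patches_run")
print([call[0] for call in workflow.calls])
```

With the real library, `workflow` would instead be an InterpretableWorkflow constructed from an AnnData object, as in the class signature at the top of this page.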
- InterpretableWorkflow.run_model(max_epochs=1500, convergence_threshold=0.0001, convergence_window=100, classifier_warmup=0, classifier_aggression=0, params_save_path=None)#
Runs the model on the attached data object.
- Parameters:
max_epochs (int, default: 1500) – Maximum number of epochs to run.
convergence_threshold (float, default: 1e-4) – Minimum improvement required to continue training.
convergence_window (int, default: 100) – Number of epochs to wait until a new minimum is attained.
classifier_warmup (int, default: 0) – Number of epochs to run the classifier before running the entire model.
classifier_aggression (int, default: 0) – Number of epochs the classifier takes independently between jointly trained epochs. Used for Patches.
params_save_path (str, optional) – If provided, saves the model to the specified path.
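One plausible reading of how convergence_threshold and convergence_window interact is a patience-style stopping rule: training ends once a full window of epochs passes without the best loss improving by more than the threshold. The sketch below implements that reading on a toy loss curve. The function `epochs_until_stop` is hypothetical, and the rule the library actually applies may differ in detail.

```python
def epochs_until_stop(losses, convergence_threshold=1e-4,
                      convergence_window=100, max_epochs=1500):
    """Return how many epochs run before the stopping rule triggers.

    Assumption: stop after `convergence_window` consecutive epochs in which
    the best loss did not improve by more than `convergence_threshold`.
    """
    best = float("inf")
    stale = 0  # epochs since the last sufficient improvement
    for epoch, loss in enumerate(losses[:max_epochs], start=1):
        if best - loss > convergence_threshold:
            best = loss
            stale = 0
        else:
            stale += 1
        if stale >= convergence_window:
            return epoch
    # No early stop: every available epoch (up to max_epochs) was run.
    return min(len(losses), max_epochs)


# Toy loss curve: three improving epochs followed by a plateau.
toy_losses = [10.0, 5.0, 2.0] + [2.0] * 10

stopped_at = epochs_until_stop(toy_losses, convergence_window=3)
full_run = epochs_until_stop(toy_losses)  # default window of 100
print(stopped_at, full_run)
```

A larger window tolerates longer plateaus before stopping, while a smaller threshold counts smaller decreases as progress.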