ladder.scripts.workflows.CrossConditionWorkflow

class ladder.scripts.workflows.CrossConditionWorkflow(anndata, verbose=False, random_seed=None)

Cross-condition workflow for training with a non-linear decoder.

Inherits from BaseWorkflow and adds functionality for running a cross-conditional model, aimed at more precise reconstructions and transfers.

Parameters:
  • anndata (AnnData) – The dataset object to be used throughout the analyses.

  • verbose (bool, default: False) – If True, prints progress messages for various methods within the module.

  • random_seed (int, optional) – If given, the internal modules are seeded with this value.

Attributes table

METRICS_REG

OPT_CLASS1

OPT_CLASS2

OPT_DEFAULTS

OPT_LIST

SEP_METRICS_REG

Methods table

evaluate_reconstruction([subset, cell_type, ...])

Evaluates the quality of reconstructions with generative metrics.

evaluate_separability([factor])

Evaluates the separability of latent embeddings for conditions.

evaluate_transfer(source, target[, ...])

Evaluates the quality of transfers with generative metrics.

load_model(params_load_path)

Loads parameters for the attached model.

plot_loss([save_loss_path])

Simple plotter for loss functions.

prep_model(factors[, batch_key, ...])

Prepares the model to be run.

run_model([max_epochs, ...])

Runs the model on the attached data object.

save_model(params_save_path)

Saves the attached model.

write_embeddings()

Places the calculated cell embeddings from the trained model under the corresponding obsm field.

Attributes

CrossConditionWorkflow.METRICS_REG = {'chamfer': 'Chamfer Discrepancy', 'corr': 'Profile Correlation', 'rmse': 'RMSE', 'swd': '2-Sliced Wasserstein'}
CrossConditionWorkflow.OPT_CLASS1 = ['SCVI', 'SCANVI']
CrossConditionWorkflow.OPT_CLASS2 = ['Patches']
CrossConditionWorkflow.OPT_DEFAULTS = {'betas': (0.9, 0.999), 'eps': 0.01, 'gamma': 1, 'lr': 0.001, 'milestones': [10000000000.0]}
CrossConditionWorkflow.OPT_LIST = ['optimizer', 'optim_args', 'gamma', 'milestones', 'lr', 'eps', 'betas']
CrossConditionWorkflow.SEP_METRICS_REG = {'calc_asw': 'Average Silhouette Width', 'kmeans_ari': 'K-Means ARI', 'kmeans_nmi': 'K-Means NMI', 'knn_error': 'kNN Classifier Accuracy'}
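The OPT_DEFAULTS keys gamma and milestones share their names with the parameters of PyTorch's torch.optim.lr_scheduler.MultiStepLR. Assuming they carry those semantics (not confirmed by this page), the sketch below computes the learning rate after a given number of scheduler steps. With the defaults, gamma=1 and a single milestone at 1e10, the rate never changes. The function name multistep_lr is hypothetical.

```python
from bisect import bisect_right


def multistep_lr(step, lr=0.001, gamma=1.0, milestones=(1e10,)):
    """Learning rate after `step` scheduler steps under MultiStepLR-style decay.

    Assumed semantics: the base rate is multiplied by `gamma` once for every
    milestone already reached. Defaults mirror OPT_DEFAULTS.
    """
    passed = bisect_right(sorted(milestones), step)  # number of milestones <= step
    return lr * gamma ** passed


# Defaults: gamma=1 and an unreachable milestone keep the rate constant.
print(multistep_lr(1500))  # 0.001
# Hypothetical decaying configuration, for comparison only.
print(multistep_lr(150, gamma=0.1, milestones=(100, 200)))
```

Setting gamma to 1 makes the milestone value irrelevant, so the defaults amount to a constant learning rate unless both are overridden.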

Methods

CrossConditionWorkflow.evaluate_reconstruction(subset=None, cell_type=None, n_iter=5)

Evaluates the quality of reconstructions with generative metrics.

Parameters:
  • subset (str, optional) – Key from levels used to restrict the evaluation to cells of a specific condition.

  • cell_type (str, optional) – If given, restricts the evaluation to cells of this type. Requires the cell_type_label_key attribute to be set (see prep_model()).

  • n_iter (int, default: 5) – Number of times to repeat the generative process.
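METRICS_REG names four generative metrics under the keys rmse, corr, chamfer and swd. The sketch below gives textbook numpy versions of them and repeats a stochastic generator n_iter times, reporting the mean and spread of each score. It is not ladder's implementation: reading "Profile Correlation" as the Pearson correlation of mean expression profiles, the particular Chamfer variant, the number of projections, and every helper name (including the generate callable standing in for the model) are assumptions.

```python
import numpy as np


def rmse(x, y):
    # x, y: (cells, genes) arrays whose rows are matched cell by cell.
    return float(np.sqrt(np.mean((x - y) ** 2)))


def profile_corr(x, y):
    # Pearson correlation between the per-gene mean expression profiles.
    return float(np.corrcoef(x.mean(axis=0), y.mean(axis=0))[0, 1])


def chamfer(x, y):
    # Symmetric Chamfer discrepancy: mean nearest-neighbour distance in both directions.
    d = np.sqrt(((x[:, None, :] - y[None, :, :]) ** 2).sum(-1))
    return float(d.min(axis=1).mean() + d.min(axis=0).mean())


def swd2(x, y, n_proj=50, rng=None):
    # Monte-Carlo 2-sliced Wasserstein distance; assumes equal numbers of cells.
    rng = np.random.default_rng(rng)
    dirs = rng.normal(size=(n_proj, x.shape[1]))
    dirs /= np.linalg.norm(dirs, axis=1, keepdims=True)  # random unit directions
    # In 1-D, optimal transport between equal-size samples matches sorted values.
    px, py = np.sort(x @ dirs.T, axis=0), np.sort(y @ dirs.T, axis=0)
    return float(np.sqrt(np.mean((px - py) ** 2)))


def evaluate(real, generate, n_iter=5, seed=0):
    """Score n_iter draws from `generate`; return (mean, std) per metric key."""
    rng = np.random.default_rng(seed)
    metrics = {"rmse": rmse, "corr": profile_corr, "chamfer": chamfer,
               "swd": lambda a, b: swd2(a, b, rng=rng)}
    scores = {name: [] for name in metrics}
    for _ in range(n_iter):
        fake = generate(rng)  # one stochastic reconstruction of `real`
        for name, fn in metrics.items():
            scores[name].append(fn(real, fake))
    return {name: (float(np.mean(v)), float(np.std(v))) for name, v in scores.items()}


real = np.random.default_rng(1).normal(size=(100, 20))


def noisy(r):
    # Stand-in generator: the real data plus small Gaussian noise.
    return real + r.normal(scale=0.1, size=real.shape)


print(evaluate(real, noisy, n_iter=5))
```

Reporting a spread alongside the mean is one reason to repeat the generative process: a single draw from a stochastic decoder can over- or under-state the score.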

CrossConditionWorkflow.evaluate_separability(factor=None)

Evaluates the separability of latent embeddings for conditions.

Parameters:
  • factor (str, optional) – A factor listed in BaseWorkflow.factors. If not provided, the metrics are evaluated on the combinations of all factors.
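SEP_METRICS_REG lists the separability metrics. As an illustration only, the sketch below implements two of them in numpy, average silhouette width and a leave-one-out kNN classifier accuracy, and shows one way to merge several factors into a single combined label for the case where no factor is given. The helper names, the "_" separator, k=5 and the Euclidean distance are assumptions. K-means ARI and NMI are commonly computed with scikit-learn's KMeans, adjusted_rand_score and normalized_mutual_info_score.

```python
import numpy as np


def combine_factors(*columns):
    """Join per-cell factor values into one label, e.g. ("ctrl", "d1") -> "ctrl_d1"."""
    return np.array(["_".join(map(str, values)) for values in zip(*columns)])


def pairwise_dist(z):
    # Euclidean distance matrix between all pairs of embedded cells.
    return np.sqrt(((z[:, None, :] - z[None, :, :]) ** 2).sum(-1))


def average_silhouette_width(z, labels):
    """Mean silhouette over cells; needs at least two distinct labels."""
    d = pairwise_dist(z)
    groups = np.unique(labels)
    sil = np.zeros(len(z))
    for i in range(len(z)):
        same = labels == labels[i]
        same[i] = False
        if not same.any():
            continue  # singleton group: silhouette taken as 0
        a = d[i, same].mean()  # mean distance to the cell's own group
        b = min(d[i, labels == g].mean() for g in groups if g != labels[i])  # nearest other group
        sil[i] = (b - a) / max(a, b)
    return float(sil.mean())


def knn_accuracy(z, labels, k=5):
    """Leave-one-out accuracy of a majority-vote kNN classifier on the embedding."""
    d = pairwise_dist(z)
    np.fill_diagonal(d, np.inf)  # a cell may not vote for itself
    hits = 0
    for i in range(len(z)):
        votes = labels[np.argsort(d[i])[:k]]
        values, counts = np.unique(votes, return_counts=True)
        hits += int(values[np.argmax(counts)] == labels[i])
    return hits / len(z)


rng = np.random.default_rng(0)
z = np.vstack([rng.normal(0, 0.1, (20, 2)), rng.normal(5, 0.1, (20, 2))])
labels = combine_factors(["ctrl"] * 20 + ["stim"] * 20, ["d1"] * 40)
print(average_silhouette_width(z, labels), knn_accuracy(z, labels))
```

Higher values of both scores indicate that conditions occupy more clearly separated regions of the latent space.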

CrossConditionWorkflow.evaluate_transfer(source, target, cell_type=None, n_iter=10)

Evaluates the quality of transfers with generative metrics.

Parameters:

CrossConditionWorkflow.load_model(params_load_path)

Loads parameters into the attached model. prep_model() must be run first.

Parameters:
  • params_load_path (str) – Path prefix of the saved model parameters. Pass only the shared prefix, without the trailing “_torch.pth” or “_pyro.pth”.
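Because load_model expects only the shared prefix, passing a full file name is an easy mistake. Assuming the files on disk are exactly the prefix followed by “_torch.pth” and “_pyro.pth”, as the parameter description states, a small hypothetical helper can convert between the two forms. Neither param_files nor to_prefix is part of ladder.

```python
from pathlib import Path

# Suffixes named in the load_model documentation.
SUFFIXES = ("_torch.pth", "_pyro.pth")


def param_files(prefix):
    """Files a shared prefix is assumed to refer to (hypothetical helper)."""
    return [Path(f"{prefix}{suffix}") for suffix in SUFFIXES]


def to_prefix(path):
    """Strip a trailing parameter suffix, leaving the prefix load_model expects."""
    name = str(path)
    for suffix in SUFFIXES:
        if name.endswith(suffix):
            return name[: -len(suffix)]
    return name


print(param_files("runs/patches_v1"))
print(to_prefix("runs/patches_v1_pyro.pth"))  # runs/patches_v1
```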

CrossConditionWorkflow.plot_loss(save_loss_path=None)

Simple plotter for loss functions.

Parameters:
  • save_loss_path (str, optional) – If provided, saves the figure to this location. Requires the full file name including the extension (e.g. fig.png).

CrossConditionWorkflow.prep_model(factors, batch_key=None, cell_type_label_key=None, minibatch_size=128, model_type='Patches', model_args=None, optim_args=None)

Prepares the model to be run.

The choice of model implicitly determines which condition encodings are used, so no separate data-preparation step is needed.

Parameters:
  • factors (list) – Factors (columns of obs) to register with the model.

  • batch_key (str, optional) – Defines the workflow to be used and affects the model structure. Can later be accessed through the attribute of the same name.

  • cell_type_label_key (str, optional) – Key in obs holding cell type labels. Required for cell-type-specific evaluation.

  • minibatch_size (int, default: 128) – Size of the minibatches used during training.

  • model_type (Literal["SCVI", "SCANVI", "Patches"], default: “Patches”) – Specifies the model attached to the current workflow.

  • model_args (dict, optional) – Arguments passed to the low-level model constructor. See models for details.

  • optim_args (dict, optional) – Optimizer arguments passed to the low-level trainer. See training for details.
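The class attributes OPT_DEFAULTS, OPT_LIST, OPT_CLASS1 and OPT_CLASS2 hint at how optim_args and model_type fit together. The sketch below is a hypothetical reading, not ladder's code: it assumes user-supplied optim_args override OPT_DEFAULTS key by key, that keys outside OPT_LIST are rejected, and that model_type must be one of OPT_CLASS1 + OPT_CLASS2. The function resolve_optim_args is invented for illustration.

```python
# Values copied from the CrossConditionWorkflow class attributes.
OPT_DEFAULTS = {"betas": (0.9, 0.999), "eps": 0.01, "gamma": 1, "lr": 0.001,
                "milestones": [10000000000.0]}
OPT_LIST = ["optimizer", "optim_args", "gamma", "milestones", "lr", "eps", "betas"]
OPT_CLASS1 = ["SCVI", "SCANVI"]
OPT_CLASS2 = ["Patches"]


def resolve_optim_args(optim_args=None, model_type="Patches"):
    """Merge user optimizer settings over the defaults (assumed behaviour)."""
    if model_type not in OPT_CLASS1 + OPT_CLASS2:
        raise ValueError(f"unknown model_type: {model_type!r}")
    user = dict(optim_args or {})
    unknown = sorted(set(user) - set(OPT_LIST))
    if unknown:
        raise KeyError(f"unsupported optimizer keys: {unknown}")
    return {**OPT_DEFAULTS, **user}  # user values take precedence over defaults


print(resolve_optim_args({"lr": 1e-4}, model_type="SCVI"))
```

Under this reading, overriding a single key such as lr leaves eps, betas and the (non-decaying) schedule at their defaults.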

CrossConditionWorkflow.run_model(max_epochs=1500, convergence_threshold=0.0001, convergence_window=100, classifier_warmup=0, classifier_aggression=0, params_save_path=None)

Runs the model on the attached data object.

Parameters:
  • max_epochs (int, default: 1500) – Maximum number of epochs to run.

  • convergence_threshold (float, default: 1e-4) – Minimum improvement required to continue training.

  • convergence_window (int, default: 100) – Number of epochs to wait for a new minimum before training stops.

  • classifier_warmup (int, default: 0) – Number of epochs to train the classifier alone before training the entire model.

  • classifier_aggression (int, default: 0) – Number of epochs the classifier is trained independently between jointly trained epochs. Used for Patches.

  • params_save_path (str, optional) – If provided, saves the model to this path.
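One plausible reading of the convergence and classifier parameters is sketched below. It assumes convergence_threshold is an absolute improvement over the best loss seen so far and that training stops once convergence_window epochs pass without such an improvement. It also assumes classifier_warmup classifier-only epochs come first, after which each joint epoch is followed by classifier_aggression classifier-only epochs. Both functions are hypothetical and only simulate the schedule; they do not train anything.

```python
def stop_epoch(losses, threshold=1e-4, window=100):
    """Epoch at which training would stop early, or None if it never does."""
    best = float("inf")
    stale = 0  # epochs since the last sufficient improvement
    for epoch, loss in enumerate(losses):
        if best - loss > threshold:
            best, stale = loss, 0
        else:
            stale += 1
        if stale >= window:
            return epoch
    return None


def phase_schedule(max_epochs, warmup=0, aggression=0):
    """Label each epoch as 'classifier' (classifier only) or 'joint'."""
    phases = []
    for epoch in range(max_epochs):
        if epoch < warmup:
            phases.append("classifier")
        elif (epoch - warmup) % (aggression + 1) == 0:
            phases.append("joint")
        else:
            phases.append("classifier")
    return phases


print(stop_epoch([1.0] * 10, window=3))  # 3
print(phase_schedule(7, warmup=2, aggression=2))
```

With both classifier parameters left at 0, every epoch is joint, which matches the defaults for models other than Patches.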

CrossConditionWorkflow.save_model(params_save_path)

Saves the attached model.

Parameters:
  • params_save_path (str) – Path to save the model parameters to. Pass only the name, without extensions.

CrossConditionWorkflow.write_embeddings()

Places the calculated cell embeddings from the trained model under the corresponding obsm field.

Each model writes its latent under a separately named field, so multiple workflows running on the same object instance do not overwrite each other's embeddings.