ladder.scripts.workflows.CrossConditionWorkflow
- class ladder.scripts.workflows.CrossConditionWorkflow(anndata, verbose=False, random_seed=None)
Cross-condition workflow for training with a non-linear decoder.
Inherits BaseWorkflow and adds functionality for running a cross-conditional model, aimed at more precise reconstructions and transfers.
- Parameters:
Attributes table
Methods table
| evaluate_reconstruction | Evaluates the quality of reconstructions with generative metrics. |
| evaluate_separability | Evaluates the separability of latent embeddings for conditions. |
| evaluate_transfer | Evaluates the quality of transfers with generative metrics. |
| load_model | Loads parameters for the attached model. |
| plot_loss | Simple plotter for loss functions. |
| prep_model | Prepares the model to be run. |
| run_model | Runs the model on the attached data object. |
| | Saves the attached model. |
| | Places the calculated cell embeddings from the trained model under the corresponding |
Attributes
- CrossConditionWorkflow.METRICS_REG = {'chamfer': 'Chamfer Discrepancy', 'corr': 'Profile Correlation', 'rmse': 'RMSE', 'swd': '2-Sliced Wasserstein'}
- CrossConditionWorkflow.OPT_CLASS1 = ['SCVI', 'SCANVI']
- CrossConditionWorkflow.OPT_CLASS2 = ['Patches']
- CrossConditionWorkflow.OPT_DEFAULTS = {'betas': (0.9, 0.999), 'eps': 0.01, 'gamma': 1, 'lr': 0.001, 'milestones': [10000000000.0]}
- CrossConditionWorkflow.OPT_LIST = ['optimizer', 'optim_args', 'gamma', 'milestones', 'lr', 'eps', 'betas']
- CrossConditionWorkflow.SEP_METRICS_REG = {'calc_asw': 'Average Silhouette Width', 'kmeans_ari': 'K-Means ARI', 'kmeans_nmi': 'K-Means NMI', 'knn_error': 'kNN Classifier Accuracy'}
Methods
- CrossConditionWorkflow.evaluate_reconstruction(subset=None, cell_type=None, n_iter=5)
Evaluates the quality of reconstructions with generative metrics.
- Parameters:
subset (str, optional) – Key from levels to subset cells for a specific condition before evaluating reconstruction.
cell_type (str, optional) – Requires cell_type_label_key to be defined as attribute. Subset cells to a single type before evaluating reconstruction.
n_iter (int, default: 5) – Number of times to repeat the generative process.
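To make two of the names registered in METRICS_REG concrete, the following is a minimal sketch of RMSE and a profile correlation computed between an observed and a reconstructed expression profile. It is not the library's implementation: the function names, the choice of Pearson correlation for "Profile Correlation", and the toy values are assumptions made for illustration.

```python
import math


def rmse(observed, reconstructed):
    """Root-mean-square error between two equally long profiles."""
    if len(observed) != len(reconstructed):
        raise ValueError("profiles must have the same length")
    squared = [(o - r) ** 2 for o, r in zip(observed, reconstructed)]
    return math.sqrt(sum(squared) / len(squared))


def profile_correlation(observed, reconstructed):
    """Pearson correlation between two profiles (assumed definition)."""
    n = len(observed)
    mean_o = sum(observed) / n
    mean_r = sum(reconstructed) / n
    cov = sum((o - mean_o) * (r - mean_r) for o, r in zip(observed, reconstructed))
    var_o = sum((o - mean_o) ** 2 for o in observed)
    var_r = sum((r - mean_r) ** 2 for r in reconstructed)
    return cov / math.sqrt(var_o * var_r)


# Toy mean-expression profiles over four hypothetical genes.
observed = [1.0, 2.0, 3.0, 4.0]
reconstructed = [1.5, 2.5, 3.5, 4.5]

print(rmse(observed, reconstructed))
print(profile_correlation(observed, reconstructed))
```

A constant offset between the two profiles, as in the toy values, shows why the two metrics are complementary: the correlation is unaffected by the shift while the RMSE reports it.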
- CrossConditionWorkflow.evaluate_separability(factor=None)
Evaluates the separability of latent embeddings for conditions.
- Parameters:
factor (str, optional) – Item listed in BaseWorkflow.factors. If not provided, the metrics will be evaluated on the combinations of factors.
- CrossConditionWorkflow.evaluate_transfer(source, target, cell_type=None, n_iter=10)
Evaluates the quality of transfers with generative metrics.
- Parameters:
source (str) – Key from BaseWorkflow.levels that selects the source condition.
target (str) – Key from BaseWorkflow.levels that selects the target condition.
cell_type (str, optional) – Requires BaseWorkflow.cell_type_label_key to be defined as an attribute. Subsets cells to a single type before evaluating transfer.
n_iter (int, default: 10) – Number of times to repeat the generative process.
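The documented methods suggest a prepare, train, evaluate sequence. The sketch below chains them on an already constructed workflow, using only the method names and keyword arguments listed on this page. The helper name, the stand-in class, the factor name "condition", and the level keys "control" and "treated" are hypothetical, and the ordering of the calls is an assumption rather than a documented requirement. The return value of evaluate_transfer is passed through untouched because its type is not described here.

```python
def evaluate_condition_transfer(workflow, factors, source, target, n_iter=10):
    """Prepare, train, and evaluate a transfer on a constructed workflow."""
    workflow.prep_model(factors, model_type="Patches")
    workflow.run_model(max_epochs=1500)
    return workflow.evaluate_transfer(source, target, n_iter=n_iter)


class RecordingWorkflow:
    """Stand-in that only records calls, so the sketch runs without the library.

    It is not part of ladder; a real CrossConditionWorkflow would be passed instead.
    """

    def __init__(self):
        self.calls = []

    def prep_model(self, factors, **kwargs):
        self.calls.append("prep_model")

    def run_model(self, **kwargs):
        self.calls.append("run_model")

    def evaluate_transfer(self, source, target, cell_type=None, n_iter=10):
        self.calls.append("evaluate_transfer")
        return None


demo = RecordingWorkflow()
evaluate_condition_transfer(demo, ["condition"], "control", "treated", n_iter=3)
print(demo.calls)
```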
- CrossConditionWorkflow.load_model(params_load_path)
Loads parameters for the attached model. Needs prep_model() to be run first.
- Parameters:
params_load_path (str) – Path to find model parameters. Expects only the shared prefix, and not the trailing “_torch.pth” or “_pyro.pth”.
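Because params_load_path is a shared prefix rather than a file name, it can help to see which file names the prefix stands for. The helpers below are a small sketch, not part of the library: both function names and the example path are hypothetical, and the only facts taken from the description above are the two suffixes.

```python
SUFFIXES = ("_torch.pth", "_pyro.pth")


def checkpoint_files(prefix):
    """Return the two file names that share the given prefix."""
    return tuple(prefix + suffix for suffix in SUFFIXES)


def to_prefix(path):
    """Strip a trailing suffix if a full file name was passed by mistake."""
    for suffix in SUFFIXES:
        if path.endswith(suffix):
            return path[: -len(suffix)]
    return path


print(checkpoint_files("params/run1"))
print(to_prefix("params/run1_pyro.pth"))
```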
- CrossConditionWorkflow.plot_loss(save_loss_path=None)
Simple plotter for loss functions.
- Parameters:
save_loss_path (str, optional) – If provided, saves the figure to the specified location. Requires the full file name including the extension (e.g. fig.png).
- CrossConditionWorkflow.prep_model(factors, batch_key=None, cell_type_label_key=None, minibatch_size=128, model_type='Patches', model_args=None, optim_args=None)
Prepares the model to be run.
The choice of model implicitly decides the kind of condition encodings to use, so no separate data preparation step is needed.
- Parameters:
batch_key (str, optional) – Defines the workflow to be used. Affects model structure. Can later be accessed through the attribute of the same name.
cell_type_label_key (str, optional) – Cell type labels in obs, required if cell-type specific evaluation is desired.
minibatch_size (int, default: 128) – Size of the minibatch to be provided during training.
model_type (Literal["SCVI", "SCANVI", "Patches"], default: “Patches”) – Specifies the model attached to the current workflow.
model_args (dict, optional) – Model arguments passed to the low-level model constructor. See models for details.
optim_args (dict, optional) – Optimizer arguments passed to the low-level trainer. See training for details.
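One plausible way to think about optim_args is as a set of overrides layered on top of OPT_DEFAULTS. The sketch below shows that merge in plain Python. It is an assumption about the intent, not the library's code: the helper name is hypothetical, and only the default values are copied from the attribute listed on this page.

```python
import copy

# Values copied from CrossConditionWorkflow.OPT_DEFAULTS as listed on this page.
OPT_DEFAULTS = {
    "betas": (0.9, 0.999),
    "eps": 0.01,
    "gamma": 1,
    "lr": 0.001,
    "milestones": [10000000000.0],
}


def resolve_optim_args(optim_args=None):
    """Overlay user-supplied optimizer arguments on a copy of the defaults."""
    resolved = copy.deepcopy(OPT_DEFAULTS)  # keep the shared defaults unmodified
    resolved.update(optim_args or {})
    return resolved


print(resolve_optim_args({"lr": 0.01}))
```

With the defaults as listed, gamma is 1 and the single milestone is far beyond any realistic epoch count, so under this reading the learning rate stays constant unless optim_args overrides those values.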
- CrossConditionWorkflow.run_model(max_epochs=1500, convergence_threshold=0.0001, convergence_window=100, classifier_warmup=0, classifier_aggression=0, params_save_path=None)
Runs the model on the attached data object.
- Parameters:
max_epochs (int, default: 1500) – Maximum number of epochs to run.
convergence_threshold (float, default: 1e-4) – Minimum improvement required to continue training.
convergence_window (int, default: 100) – Number of epochs to wait for a new minimum to be attained.
classifier_warmup (int, default: 0) – Number of epochs to run the classifier before running the entire model.
classifier_aggression (int, default: 0) – Number of epochs the classifier takes independently between jointly trained epochs. Used for Patches.
params_save_path (str, optional) – If provided, saves the model to the specified path.
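The interaction between convergence_threshold and convergence_window can be illustrated with a generic patience-based stopping rule. The sketch below is one common reading of the two parameter descriptions, not the library's training loop: the function name and the toy loss values are hypothetical, and the defaults are the ones shown in the run_model signature.

```python
def stopping_epoch(losses, convergence_threshold=0.0001, convergence_window=100):
    """Return the 1-based epoch at which training would stop.

    A new minimum counts only if it improves on the best loss by more than
    convergence_threshold. Training stops once convergence_window epochs in
    a row pass without such an improvement, or when the losses run out.
    """
    best = float("inf")
    epochs_since_best = 0
    for epoch, loss in enumerate(losses, start=1):
        if best - loss > convergence_threshold:
            best = loss
            epochs_since_best = 0
        else:
            epochs_since_best += 1
            if epochs_since_best >= convergence_window:
                return epoch
    return len(losses)


# Toy loss curve that plateaus at 0.4 after the third epoch.
losses = [1.0, 0.5, 0.4] + [0.4] * 10
print(stopping_epoch(losses, convergence_window=3))
```

In this reading, max_epochs plays the role of the length of the loss list: training ends there if the window is never exhausted.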